题目链接
题目描述
为了追踪突发热点,需要在“查询时刻 之前的最近
篇文档”内,根据加权余弦相似度挑选最相关的文档。
具体规则如下:
- 查询窗口:对于查询时刻
,窗口为
min(t, N-1)
号文档及其之前的总共篇文档。
- TF-IDF 词向量:
- 词频 TF(
) 是词语
在文档
中出现的次数。
- 逆文档频率 IDF(
) 计算采用平滑公式:
,其中
是窗口内的文档总数(即
),
是窗口内包含词语
的文档数。
- 词频 TF(
- 加权余弦相似度:
- 查询
与文档
的原始余弦相似度为
。
- 窗口内从旧到新(即文档编号从小到大)的第
篇文档(
),其时间权重为
。
- 最终相似度 = 原始余弦相似度
时间权重。
- 查询
- 筛选规则:
- 找出相似度
且最高的文档。
- 若存在多个最高分,返回窗口中最早的(即编号最小的)文档。
- 若没有满足条件的文档,输出 -1。
- 找出相似度
输入:
- 文档总数
行文档内容
- 窗口大小
- 查询总数
行查询,每行格式为 “
”
输出:
个数字,表示每次查询的结果。
解题思路
本题是一道复杂的模拟题,核心是为每次查询动态计算窗口内文档的 TF-IDF 向量,并据此计算加权余弦相似度。需要严格按照题目定义的公式和流程进行计算。
对于每一次查询 ,算法步骤如下:
-
确定查询窗口: 根据题目描述和示例推断,查询时刻
对应的窗口是文档编号从
end_idx - K + 1
到end_idx
的文档,其中end_idx = min(t, N-1)
。 -
处理查询和文档: 将查询短语
和窗口内的
篇文档都进行分词,并统计每个词的词频(TF)。使用
map<string, int>
来存储词频。 -
构建词汇表和计算 DF: 遍历查询
和窗口内所有文档中的每一个词,构建当前查询的词汇表。同时,计算词汇表中每个词的文档频率
,即它在窗口内多少个文档中出现过。
-
计算 TF-IDF 向量: 对于词汇表中的每一个词
:
- 计算其 IDF 值:
。
- 查询向量
在
维度上的分量为
。
- 窗口内每个文档
在
维度上的分量为
。
- 将这些分量存储在
map<string, double>
结构的向量中。
- 计算其 IDF 值:
-
计算加权余弦相似度: 遍历窗口内的每一篇文档
(其中
从
start_idx
到end_idx
):- 计算查询向量
和文档向量
的点积
。
- 计算两个向量的 L2 范数(模长)
和
。
- 计算原始余弦相似度:
。如果分母为 0,则相似度为 0。
- 计算时间权重。窗口中第
篇文档(
)的文档编号为
start_idx + j - 1
。其权重为。
- 最终相似度
。
- 计算查询向量
-
筛选最佳文档:
- 维护一个变量
max_sim
记录最高相似度(初始化为 -1.0),best_doc_id
记录最佳文档编号(初始化为 -1)。 - 遍历计算出的每个文档的最终相似度
:
- 如果
且
,则更新
max_sim = S_final
和best_doc_id = i
。
- 如果
- 根据规则,若相似度并列,返回编号最小的文档。由于我们是按文档编号从小到大的顺序遍历的,所以只有在严格大于当前最大相似度时才更新,天然地满足了这一要求。
- 维护一个变量
-
输出结果: 完成对一个查询的处理后,输出
best_doc_id
。对所有次查询重复以上步骤。
代码
#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <map>
#include <cmath>
#include <algorithm>
#include <iomanip>
#include <set>
using namespace std;
// 分词并计算词频
map<string, int> get_tf(const string& text) {
map<string, int> tf;
stringstream ss(text);
string word;
while (ss >> word) {
tf[word]++;
}
return tf;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int n_total;
cin >> n_total;
cin.ignore();
vector<string> docs(n_total);
for (int i = 0; i < n_total; ++i) {
getline(cin, docs[i]);
}
int k;
cin >> k;
int p;
cin >> p;
cin.ignore();
vector<int> results;
for (int i = 0; i < p; ++i) {
string line;
getline(cin, line);
stringstream ss(line);
int t;
ss >> t;
string q_text;
getline(ss, q_text);
// 移除前导空格
if (!q_text.empty() && q_text[0] == ' ') {
q_text = q_text.substr(1);
}
int end_idx = min(t, n_total - 1);
int start_idx = end_idx - k + 1;
vector<map<string, int>> window_docs_tf;
for (int j = start_idx; j <= end_idx; ++j) {
window_docs_tf.push_back(get_tf(docs[j]));
}
map<string, int> q_tf = get_tf(q_text);
set<string> vocab;
for (const auto& pair : q_tf) {
vocab.insert(pair.first);
}
for (const auto& doc_tf : window_docs_tf) {
for (const auto& pair : doc_tf) {
vocab.insert(pair.first);
}
}
map<string, int> df;
for (const string& word : vocab) {
df[word] = 0;
for (const auto& doc_tf : window_docs_tf) {
if (doc_tf.count(word)) {
df[word]++;
}
}
}
map<string, double> idf;
for (const string& word : vocab) {
idf[word] = log((double)(k + 1) / (df[word] + 1)) + 1.0;
}
map<string, double> q_vec;
double q_norm = 0.0;
for (const string& word : vocab) {
double tf_val = q_tf.count(word) ? q_tf[word] : 0;
q_vec[word] = tf_val * idf[word];
q_norm += q_vec[word] * q_vec[word];
}
q_norm = sqrt(q_norm);
double max_sim = -1.0;
int best_doc_id = -1;
double best_doc_base_sim = -1.0;
for (int j = 0; j < k; ++j) {
int doc_idx = start_idx + j;
map<string, double> doc_vec;
double doc_norm = 0.0;
double dot_product = 0.0;
for (const string& word : vocab) {
double tf_val = window_docs_tf[j].count(word) ? window_docs_tf[j][word] : 0;
doc_vec[word] = tf_val * idf[word];
doc_norm += doc_vec[word] * doc_vec[word];
dot_product += q_vec[word] * doc_vec[word];
}
doc_norm = sqrt(doc_norm);
double cos_sim = 0.0;
if (q_norm > 1e-9 && doc_norm > 1e-9) {
cos_sim = dot_product / (q_norm * doc_norm);
}
double weight = (double)(j + 1) / k;
double weighted_sim = cos_sim * weight;
if (weighted_sim > max_sim) {
max_sim = weighted_sim;
best_doc_id = doc_idx;
best_doc_base_sim = cos_sim;
}
}
results.push_back(best_doc_base_sim >= 0.6 ? best_doc_id : -1);
}
for (int i = 0; i < results.size(); ++i) {
cout << results[i] << (i == results.size() - 1 ? "" : " ");
}
cout << endl;
return 0;
}
import java.util.*;
import java.io.*;
import java.lang.Math;
public class Main {
// 分词并计算词频
private static Map<String, Integer> getTf(String text) {
Map<String, Integer> tf = new HashMap<>();
String[] words = text.split("\\s+");
for (String word : words) {
if (!word.isEmpty()) {
tf.put(word, tf.getOrDefault(word, 0) + 1);
}
}
return tf;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int nTotal = sc.nextInt();
sc.nextLine();
List<String> docs = new ArrayList<>();
for (int i = 0; i < nTotal; ++i) {
docs.add(sc.nextLine());
}
int k = sc.nextInt();
int p = sc.nextInt();
sc.nextLine();
List<Integer> results = new ArrayList<>();
for (int i = 0; i < p; ++i) {
int t = sc.nextInt();
String qText = sc.nextLine().trim();
int endIdx = Math.min(t, nTotal - 1);
int startIdx = endIdx - k + 1;
List<Map<String, Integer>> windowDocsTf = new ArrayList<>();
for (int j = startIdx; j <= endIdx; ++j) {
windowDocsTf.add(getTf(docs.get(j)));
}
Map<String, Integer> qTf = getTf(qText);
Set<String> vocab = new HashSet<>();
vocab.addAll(qTf.keySet());
for (Map<String, Integer> docTf : windowDocsTf) {
vocab.addAll(docTf.keySet());
}
Map<String, Integer> df = new HashMap<>();
for (String word : vocab) {
df.put(word, 0);
for (Map<String, Integer> docTf : windowDocsTf) {
if (docTf.containsKey(word)) {
df.put(word, df.get(word) + 1);
}
}
}
Map<String, Double> idf = new HashMap<>();
for (String word : vocab) {
idf.put(word, Math.log((double)(k + 1) / (df.get(word) + 1)) + 1.0);
}
Map<String, Double> qVec = new HashMap<>();
double qNorm = 0.0;
for (String word : vocab) {
double tfVal = qTf.getOrDefault(word, 0);
double tfIdfVal = tfVal * idf.get(word);
qVec.put(word, tfIdfVal);
qNorm += tfIdfVal * tfIdfVal;
}
qNorm = Math.sqrt(qNorm);
double maxSim = -1.0;
int bestDocId = -1;
double bestDocBaseSim = -1.0;
for (int j = 0; j < k; ++j) {
int docIdx = startIdx + j;
Map<String, Integer> currentDocTf = windowDocsTf.get(j);
double docNorm = 0.0;
double dotProduct = 0.0;
for (String word : vocab) {
double tfVal = currentDocTf.getOrDefault(word, 0);
double tfIdfVal = tfVal * idf.get(word);
docNorm += tfIdfVal * tfIdfVal;
dotProduct += qVec.get(word) * tfIdfVal;
}
docNorm = Math.sqrt(docNorm);
double cosSim = 0.0;
if (qNorm > 1e-9 && docNorm > 1e-9) {
cosSim = dotProduct / (qNorm * docNorm);
}
double weight = (double)(j + 1) / k;
double weightedSim = cosSim * weight;
if (weightedSim > maxSim) {
maxSim = weightedSim;
bestDocId = docIdx;
bestDocBaseSim = cosSim;
}
}
results.add(bestDocBaseSim >= 0.6 ? bestDocId : -1);
}
for (int i = 0; i < results.size(); ++i) {
System.out.print(results.get(i) + (i == results.size() - 1 ? "" : " "));
}
System.out.println();
}
}
import math
from collections import Counter
def get_tf(text):
return Counter(text.split())
def main():
n_total = int(input())
docs = [input() for _ in range(n_total)]
k = int(input())
p = int(input())
results = []
for _ in range(p):
line = input().split(maxsplit=1)
t = int(line[0])
q_text = line[1] if len(line) > 1 else ""
end_idx = min(t, n_total - 1)
start_idx = end_idx - k + 1
window_docs = docs[start_idx : end_idx + 1]
window_docs_tf = [get_tf(doc) for doc in window_docs]
q_tf = get_tf(q_text)
vocab = set(q_tf.keys())
for doc_tf in window_docs_tf:
vocab.update(doc_tf.keys())
df = {word: 0 for word in vocab}
for word in vocab:
for doc_tf in window_docs_tf:
if word in doc_tf:
df[word] += 1
idf = {word: math.log((k + 1) / (df[word] + 1)) + 1.0 for word in vocab}
q_vec = {word: q_tf.get(word, 0) * idf[word] for word in vocab}
q_norm = math.sqrt(sum(val**2 for val in q_vec.values()))
max_sim = -1.0
best_doc_id = -1
best_base_sim = -1.0
for j in range(k):
doc_idx = start_idx + j
current_doc_tf = window_docs_tf[j]
doc_vec = {word: current_doc_tf.get(word, 0) * idf[word] for word in vocab}
doc_norm = math.sqrt(sum(val**2 for val in doc_vec.values()))
dot_product = sum(q_vec[word] * doc_vec[word] for word in vocab)
cos_sim = 0.0
if q_norm > 1e-9 and doc_norm > 1e-9:
cos_sim = dot_product / (q_norm * doc_norm)
weight = (j + 1) / k
weighted_sim = cos_sim * weight
if weighted_sim > max_sim:
max_sim = weighted_sim
best_doc_id = doc_idx
best_base_sim = cos_sim
results.append(best_doc_id if best_base_sim >= 0.6 else -1)
print(*results)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:本题解法为模拟。对每次查询,都完整地执行一次窗口构建、词汇表生成、DF/IDF 计算、向量构建和相似度比较的流程。
- 时间复杂度:
,其中
是查询次数,
是窗口大小,
是文档和查询的平均长度(用于分词),
是每次查询窗口内形成的词汇表大小。对于每次查询,分词需要
,构建词汇表和计算 DF、IDF 需要
,计算向量和相似度也需要
。
- 空间复杂度:
,其中
用于存储所有文档,
用于存储窗口内文档的词频信息,
用于存储词汇表、DF 和 IDF 值。