找出相似度最高的文档
题意
有 篇文档按时间顺序排列(编号
到
),每篇由若干小写单词组成。给定窗口大小
和
个查询,每个查询包含时刻
和查询短语
。
对于每个查询,在时刻 及之前的最近
篇文档中,用 TF-IDF 向量的加权余弦相似度找出最相关的文档。具体规则:
- 窗口:文档
- TF:词在文档中的出现次数
- IDF:
,其中
是窗口内包含词
的文档数
- 时间权重:窗口内从旧到新第
个文档(
从
开始),权重为
,越新越大
- 评分:先算查询与文档的原始余弦相似度
,仅保留
的候选,再以
排名
- 平局:取窗口中最早的文档
- 无候选则输出
思路
拿到这题先别慌,虽然公式一大堆,但本质就是一个模拟题,把每一步老老实实算出来就行。
第一个问题:窗口怎么取?
文档编号 到
,查询时刻
可能
(比如样例里
但只有
篇文档),所以窗口右端点是
,左端点往前推
个位置。
第二个问题:TF-IDF 怎么算?
TF 就是词频,直接统计。IDF 的分母用的是窗口内包含该词的文档数(不是全局的),注意 在公式里代表的是窗口大小
,不是总文档数。
第三个问题:时间权重到底怎么参与计算?
这是本题最容易搞混的地方。如果把时间权重乘进文档向量再算余弦,由于标量乘法会被范数消掉,等价于不加权——那就没意义了。
实际做法是:先算原始的余弦相似度 ,用
做门槛过滤,然后用
作为最终得分来排名。这样新文档天然有加分,老文档被打折。
第四个问题:平局怎么处理?
得分相同时选窗口中最早的那篇。我们从旧到新遍历,严格大于才更新,自然就保住了最早的。
时间复杂度 ,
是文档平均长度,足够通过。
代码
import math
import sys
from collections import Counter
def solve():
input_data = sys.stdin.read().split('\n')
idx = 0
N = int(input_data[idx]); idx += 1
docs = []
for i in range(N):
tokens = input_data[idx].strip().split()
docs.append(tokens)
idx += 1
K = int(input_data[idx]); idx += 1
P = int(input_data[idx]); idx += 1
results = []
for _ in range(P):
line = input_data[idx].strip().split(); idx += 1
t = int(line[0])
query_tokens = line[1:]
window_end = min(t, N - 1)
window_start = window_end - K + 1
# 统计窗口内每个词的文档频率
df = Counter()
for i in range(window_start, window_end + 1):
for w in set(docs[i]):
df[w] += 1
def idf(word):
return math.log((K + 1) / (df.get(word, 0) + 1)) + 1
# 查询的 TF-IDF 向量
query_tf = Counter(query_tokens)
query_vec = {w: query_tf[w] * idf(w) for w in query_tf}
q_norm = math.sqrt(sum(v * v for v in query_vec.values()))
best_score = -1
best_doc_id = -1
for j in range(K):
doc_id = window_start + j
time_weight = (j + 1) / K
doc_tf = Counter(docs[doc_id])
doc_vec = {w: doc_tf[w] * idf(w) for w in set(docs[doc_id])}
dot = sum(query_vec[w] * doc_vec[w] for w in query_vec if w in doc_vec)
if dot == 0:
continue
d_norm = math.sqrt(sum(v * v for v in doc_vec.values()))
cosine = dot / (q_norm * d_norm)
if cosine < 0.6:
continue
score = time_weight * cosine
if score > best_score:
best_score = score
best_doc_id = doc_id
results.append(str(best_doc_id))
print(' '.join(results))
solve()
#include <bits/stdc++.h>
using namespace std;
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int N;
cin >> N;
cin.ignore();
vector<vector<string>> docs(N);
for(int i = 0; i < N; i++){
string line;
getline(cin, line);
istringstream iss(line);
string w;
while(iss >> w) docs[i].push_back(w);
}
int K, P;
cin >> K >> P;
cin.ignore();
vector<string> results;
for(int q = 0; q < P; q++){
string line;
getline(cin, line);
istringstream iss(line);
int t;
iss >> t;
vector<string> query;
string w;
while(iss >> w) query.push_back(w);
int window_end = min(t, N - 1);
int window_start = window_end - K + 1;
unordered_map<string, int> df;
for(int i = window_start; i <= window_end; i++){
unordered_set<string> seen(docs[i].begin(), docs[i].end());
for(auto& s : seen) df[s]++;
}
auto idf = [&](const string& word) -> double {
int d = df.count(word) ? df[word] : 0;
return log((double)(K + 1) / (d + 1)) + 1.0;
};
unordered_map<string, double> qvec;
unordered_map<string, int> qtf;
for(auto& w : query) qtf[w]++;
double qnorm = 0;
for(auto& [word, cnt] : qtf){
double v = cnt * idf(word);
qvec[word] = v;
qnorm += v * v;
}
qnorm = sqrt(qnorm);
double best_score = -1;
int best_id = -1;
for(int j = 0; j < K; j++){
int doc_id = window_start + j;
double tw = (double)(j + 1) / K;
unordered_map<string, int> dtf;
for(auto& w : docs[doc_id]) dtf[w]++;
double dot = 0, dnorm = 0;
for(auto& [word, cnt] : dtf){
double v = cnt * idf(word);
dnorm += v * v;
if(qvec.count(word)) dot += qvec[word] * v;
}
if(dot == 0) continue;
dnorm = sqrt(dnorm);
double cosine = dot / (qnorm * dnorm);
if(cosine < 0.6) continue;
double score = tw * cosine;
if(score > best_score){
best_score = score;
best_id = doc_id;
}
}
results.push_back(to_string(best_id));
}
for(int i = 0; i < (int)results.size(); i++){
if(i) cout << ' ';
cout << results[i];
}
cout << '\n';
return 0;
}

京公网安备 11010502036488号