题目链接

找出相似度最高的文档

题目描述

为了追踪突发热点,需要在“查询时刻 之前的最近 篇文档”内,根据加权余弦相似度挑选最相关的文档。

具体规则如下:

  1. 查询窗口:对于查询时刻 ,窗口为 min(t, N-1) 号文档及其之前的总共 篇文档。
  2. TF-IDF 词向量
    • 词频 TF() 是词语 在文档 中出现的次数。
    • 逆文档频率 IDF() 计算采用平滑公式:,其中 是窗口内的文档总数(即 ), 是窗口内包含词语 的文档数。
  3. 加权余弦相似度
    • 查询 与文档 的原始余弦相似度为
    • 窗口内从旧到新(即文档编号从小到大)的第 篇文档(),其时间权重为
    • 最终相似度 = 原始余弦相似度 时间权重。
  4. 筛选规则
    • 找出相似度 且最高的文档。
    • 若存在多个最高分,返回窗口中最早的(即编号最小的)文档。
    • 若没有满足条件的文档,输出 -1。

输入:

  • 文档总数
  • 行文档内容
  • 窗口大小
  • 查询总数
  • 行查询,每行格式为 “

输出:

  • 个数字,表示每次查询的结果。

解题思路

本题是一道复杂的模拟题,核心是为每次查询动态计算窗口内文档的 TF-IDF 向量,并据此计算加权余弦相似度。需要严格按照题目定义的公式和流程进行计算。

对于每一次查询 ,算法步骤如下:

  1. 确定查询窗口: 根据题目描述和示例推断,查询时刻 对应的窗口是文档编号从 end_idx - K + 1end_idx 的文档,其中 end_idx = min(t, N-1)

  2. 处理查询和文档: 将查询短语 和窗口内的 篇文档都进行分词,并统计每个词的词频(TF)。使用 map<string, int> 来存储词频。

  3. 构建词汇表和计算 DF: 遍历查询 和窗口内所有文档中的每一个词,构建当前查询的词汇表。同时,计算词汇表中每个词的文档频率 ,即它在窗口内多少个文档中出现过。

  4. 计算 TF-IDF 向量: 对于词汇表中的每一个词

    • 计算其 IDF 值:
    • 查询向量 维度上的分量为
    • 窗口内每个文档 维度上的分量为
    • 将这些分量存储在 map<string, double> 结构的向量中。
  5. 计算加权余弦相似度: 遍历窗口内的每一篇文档 (其中 start_idxend_idx):

    • 计算查询向量 和文档向量 的点积
    • 计算两个向量的 L2 范数(模长)
    • 计算原始余弦相似度:。如果分母为 0,则相似度为 0。
    • 计算时间权重。窗口中第 篇文档()的文档编号为 start_idx + j - 1。其权重为
    • 最终相似度
  6. 筛选最佳文档

    • 维护一个变量 max_sim 记录最高相似度(初始化为 -1.0),best_doc_id 记录最佳文档编号(初始化为 -1)。
    • 遍历计算出的每个文档的最终相似度
      • 如果 ,则更新 max_sim = S_finalbest_doc_id = i
    • 根据规则,若相似度并列,返回编号最小的文档。由于我们是按文档编号从小到大的顺序遍历的,所以只有在严格大于当前最大相似度时才更新,天然地满足了这一要求。
  7. 输出结果: 完成对一个查询的处理后,输出 best_doc_id。对所有 次查询重复以上步骤。

代码

#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <map>
#include <cmath>
#include <algorithm>
#include <iomanip>
#include <set>

using namespace std;

// 分词并计算词频
map<string, int> get_tf(const string& text) {
    map<string, int> tf;
    stringstream ss(text);
    string word;
    while (ss >> word) {
        tf[word]++;
    }
    return tf;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n_total;
    cin >> n_total;
    cin.ignore(); 

    vector<string> docs(n_total);
    for (int i = 0; i < n_total; ++i) {
        getline(cin, docs[i]);
    }

    int k;
    cin >> k;

    int p;
    cin >> p;
    cin.ignore();

    vector<int> results;
    for (int i = 0; i < p; ++i) {
        string line;
        getline(cin, line);
        stringstream ss(line);
        int t;
        ss >> t;
        string q_text;
        getline(ss, q_text);
        // 移除前导空格
        if (!q_text.empty() && q_text[0] == ' ') {
            q_text = q_text.substr(1);
        }

        int end_idx = min(t, n_total - 1);
        int start_idx = end_idx - k + 1;

        vector<map<string, int>> window_docs_tf;
        for (int j = start_idx; j <= end_idx; ++j) {
            window_docs_tf.push_back(get_tf(docs[j]));
        }
        map<string, int> q_tf = get_tf(q_text);

        set<string> vocab;
        for (const auto& pair : q_tf) {
            vocab.insert(pair.first);
        }
        for (const auto& doc_tf : window_docs_tf) {
            for (const auto& pair : doc_tf) {
                vocab.insert(pair.first);
            }
        }

        map<string, int> df;
        for (const string& word : vocab) {
            df[word] = 0;
            for (const auto& doc_tf : window_docs_tf) {
                if (doc_tf.count(word)) {
                    df[word]++;
                }
            }
        }
        
        map<string, double> idf;
        for (const string& word : vocab) {
            idf[word] = log((double)(k + 1) / (df[word] + 1)) + 1.0;
        }

        map<string, double> q_vec;
        double q_norm = 0.0;
        for (const string& word : vocab) {
            double tf_val = q_tf.count(word) ? q_tf[word] : 0;
            q_vec[word] = tf_val * idf[word];
            q_norm += q_vec[word] * q_vec[word];
        }
        q_norm = sqrt(q_norm);

        double max_sim = -1.0;
        int best_doc_id = -1;
        double best_doc_base_sim = -1.0;

        for (int j = 0; j < k; ++j) {
            int doc_idx = start_idx + j;
            map<string, double> doc_vec;
            double doc_norm = 0.0;
            double dot_product = 0.0;

            for (const string& word : vocab) {
                double tf_val = window_docs_tf[j].count(word) ? window_docs_tf[j][word] : 0;
                doc_vec[word] = tf_val * idf[word];
                doc_norm += doc_vec[word] * doc_vec[word];
                dot_product += q_vec[word] * doc_vec[word];
            }
            doc_norm = sqrt(doc_norm);
            
            double cos_sim = 0.0;
            if (q_norm > 1e-9 && doc_norm > 1e-9) {
                cos_sim = dot_product / (q_norm * doc_norm);
            }

            double weight = (double)(j + 1) / k;
            double weighted_sim = cos_sim * weight;
            
            if (weighted_sim > max_sim) {
                max_sim = weighted_sim;
                best_doc_id = doc_idx;
                best_doc_base_sim = cos_sim;
            }
        }
        results.push_back(best_doc_base_sim >= 0.6 ? best_doc_id : -1);
    }

    for (int i = 0; i < results.size(); ++i) {
        cout << results[i] << (i == results.size() - 1 ? "" : " ");
    }
    cout << endl;

    return 0;
}
import java.util.*;
import java.io.*;
import java.lang.Math;

public class Main {
    // 分词并计算词频
    private static Map<String, Integer> getTf(String text) {
        Map<String, Integer> tf = new HashMap<>();
        String[] words = text.split("\\s+");
        for (String word : words) {
            if (!word.isEmpty()) {
                tf.put(word, tf.getOrDefault(word, 0) + 1);
            }
        }
        return tf;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int nTotal = sc.nextInt();
        sc.nextLine(); 

        List<String> docs = new ArrayList<>();
        for (int i = 0; i < nTotal; ++i) {
            docs.add(sc.nextLine());
        }

        int k = sc.nextInt();
        int p = sc.nextInt();
        sc.nextLine();

        List<Integer> results = new ArrayList<>();
        for (int i = 0; i < p; ++i) {
            int t = sc.nextInt();
            String qText = sc.nextLine().trim();
            
            int endIdx = Math.min(t, nTotal - 1);
            int startIdx = endIdx - k + 1;

            List<Map<String, Integer>> windowDocsTf = new ArrayList<>();
            for (int j = startIdx; j <= endIdx; ++j) {
                windowDocsTf.add(getTf(docs.get(j)));
            }
            Map<String, Integer> qTf = getTf(qText);

            Set<String> vocab = new HashSet<>();
            vocab.addAll(qTf.keySet());
            for (Map<String, Integer> docTf : windowDocsTf) {
                vocab.addAll(docTf.keySet());
            }

            Map<String, Integer> df = new HashMap<>();
            for (String word : vocab) {
                df.put(word, 0);
                for (Map<String, Integer> docTf : windowDocsTf) {
                    if (docTf.containsKey(word)) {
                        df.put(word, df.get(word) + 1);
                    }
                }
            }
            
            Map<String, Double> idf = new HashMap<>();
            for (String word : vocab) {
                idf.put(word, Math.log((double)(k + 1) / (df.get(word) + 1)) + 1.0);
            }

            Map<String, Double> qVec = new HashMap<>();
            double qNorm = 0.0;
            for (String word : vocab) {
                double tfVal = qTf.getOrDefault(word, 0);
                double tfIdfVal = tfVal * idf.get(word);
                qVec.put(word, tfIdfVal);
                qNorm += tfIdfVal * tfIdfVal;
            }
            qNorm = Math.sqrt(qNorm);

            double maxSim = -1.0;
            int bestDocId = -1;
            double bestDocBaseSim = -1.0;

            for (int j = 0; j < k; ++j) {
                int docIdx = startIdx + j;
                Map<String, Integer> currentDocTf = windowDocsTf.get(j);
                double docNorm = 0.0;
                double dotProduct = 0.0;

                for (String word : vocab) {
                    double tfVal = currentDocTf.getOrDefault(word, 0);
                    double tfIdfVal = tfVal * idf.get(word);
                    docNorm += tfIdfVal * tfIdfVal;
                    dotProduct += qVec.get(word) * tfIdfVal;
                }
                docNorm = Math.sqrt(docNorm);
                
                double cosSim = 0.0;
                if (qNorm > 1e-9 && docNorm > 1e-9) {
                    cosSim = dotProduct / (qNorm * docNorm);
                }

                double weight = (double)(j + 1) / k;
                double weightedSim = cosSim * weight;
                
                if (weightedSim > maxSim) {
                    maxSim = weightedSim;
                    bestDocId = docIdx;
                    bestDocBaseSim = cosSim;
                }
            }
            results.add(bestDocBaseSim >= 0.6 ? bestDocId : -1);
        }

        for (int i = 0; i < results.size(); ++i) {
            System.out.print(results.get(i) + (i == results.size() - 1 ? "" : " "));
        }
        System.out.println();
    }
}
import math
from collections import Counter

def get_tf(text):
    return Counter(text.split())

def main():
    n_total = int(input())
    docs = [input() for _ in range(n_total)]
    k = int(input())
    p = int(input())

    results = []
    for _ in range(p):
        line = input().split(maxsplit=1)
        t = int(line[0])
        q_text = line[1] if len(line) > 1 else ""

        end_idx = min(t, n_total - 1)
        start_idx = end_idx - k + 1
        
        window_docs = docs[start_idx : end_idx + 1]
        window_docs_tf = [get_tf(doc) for doc in window_docs]
        q_tf = get_tf(q_text)

        vocab = set(q_tf.keys())
        for doc_tf in window_docs_tf:
            vocab.update(doc_tf.keys())

        df = {word: 0 for word in vocab}
        for word in vocab:
            for doc_tf in window_docs_tf:
                if word in doc_tf:
                    df[word] += 1
        
        idf = {word: math.log((k + 1) / (df[word] + 1)) + 1.0 for word in vocab}

        q_vec = {word: q_tf.get(word, 0) * idf[word] for word in vocab}
        q_norm = math.sqrt(sum(val**2 for val in q_vec.values()))

        max_sim = -1.0
        best_doc_id = -1
        best_base_sim = -1.0

        for j in range(k):
            doc_idx = start_idx + j
            current_doc_tf = window_docs_tf[j]
            
            doc_vec = {word: current_doc_tf.get(word, 0) * idf[word] for word in vocab}
            doc_norm = math.sqrt(sum(val**2 for val in doc_vec.values()))
            
            dot_product = sum(q_vec[word] * doc_vec[word] for word in vocab)
            
            cos_sim = 0.0
            if q_norm > 1e-9 and doc_norm > 1e-9:
                cos_sim = dot_product / (q_norm * doc_norm)

            weight = (j + 1) / k
            weighted_sim = cos_sim * weight

            if weighted_sim > max_sim:
                max_sim = weighted_sim
                best_doc_id = doc_idx
                best_base_sim = cos_sim

        results.append(best_doc_id if best_base_sim >= 0.6 else -1)

    print(*results)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:本题解法为模拟。对每次查询,都完整地执行一次窗口构建、词汇表生成、DF/IDF 计算、向量构建和相似度比较的流程。
  • 时间复杂度:,其中 是查询次数, 是窗口大小, 是文档和查询的平均长度(用于分词), 是每次查询窗口内形成的词汇表大小。对于每次查询,分词需要 ,构建词汇表和计算 DF、IDF 需要 ,计算向量和相似度也需要
  • 空间复杂度:,其中 用于存储所有文档, 用于存储窗口内文档的词频信息, 用于存储词汇表、DF 和 IDF 值。