基于卡方检验的特征选择实现

题意

给定 篇文档,每篇标注为 positivenegative,以及一个整数 。对所有出现过的单词(区分大小写)计算卡方统计量,输出卡方值最大的前 个单词。卡方值相同时按字母序升序排列。

思路

经典的卡方特征选择。对每个单词 构建 列联表:

类别为正 类别为负
不含

其中 表示「包含该词且类别为 positive 的文档数」, 表示「包含该词且类别为 negative 的文档数」,

卡方统计量公式为:

$$

注意这里是按 文档级别 统计「词是否出现」,而不是统计词频。同一篇文档中同一个词出现多次只算一次(用 Set 去重)。

实现要点:

  1. 读入每篇文档时,将内容按空格分词并放入 HashSet 去重,同时记录类别。
  2. 遍历所有出现过的单词,统计 ,进而算出
  3. 降序、字母序升序排序,输出前 个。

代码

import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = Integer.parseInt(sc.nextLine().trim());

        String[] categories = new String[n];
        List<Set<String>> docWords = new ArrayList<>();
        Set<String> allWords = new HashSet<>();
        int posCount = 0;

        for (int i = 0; i < n; i++) {
            String line = sc.nextLine();
            int tab = line.indexOf('\t');
            categories[i] = line.substring(0, tab);
            if (categories[i].equals("positive")) posCount++;
            String content = line.substring(tab + 1);
            Set<String> words = new HashSet<>(Arrays.asList(content.split(" ")));
            docWords.add(words);
            allWords.addAll(words);
        }
        int k = Integer.parseInt(sc.nextLine().trim());
        int negCount = n - posCount;

        List<String[]> results = new ArrayList<>();
        for (String word : allWords) {
            int A = 0, B = 0;
            for (int i = 0; i < n; i++) {
                if (docWords.get(i).contains(word)) {
                    if (categories[i].equals("positive")) A++;
                    else B++;
                }
            }
            int C = posCount - A;
            int D = negCount - B;
            long denom = (long)(A + B) * (C + D) * (A + C) * (B + D);
            double chi2 = 0;
            if (denom > 0) {
                long num = (long) n * ((long) A * D - (long) B * C)
                         * ((long) A * D - (long) B * C);
                chi2 = (double) num / denom;
            }
            results.add(new String[]{word, String.valueOf(chi2)});
        }

        results.sort((a, b) -> {
            double va = Double.parseDouble(a[1]);
            double vb = Double.parseDouble(b[1]);
            if (va != vb) return Double.compare(vb, va);
            return a[0].compareTo(b[0]);
        });

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < k; i++) {
            sb.append(results.get(i)[0]).append('\n');
        }
        System.out.print(sb);
    }
}

复杂度

  • 时间复杂度:,其中 为不同单词的总数。对每个单词遍历所有文档统计频次。
  • 空间复杂度:,存储每篇文档的词集合。