标签在前K个近邻中的出现次数

KNN(K近邻)分类器的核心流程是什么?算距离、排序、投票。这题就是把这个流程走一遍。

思路

拿到测试样本后,对每个训练样本算欧氏距离的平方(不用开根号,不影响排序),然后按距离从小到大排序(距离相同时按输入顺序),取前 K 个邻居统计标签出现次数。

唯一需要注意的是平票处理:如果有多个标签出现次数并列最高,选哪个?题目规定:在这些平票标签中,选距离最近的那个邻居所属的标签。

具体做法:排好序以后,从近到远扫描 K 个邻居,第一个属于平票标签集合的邻居,它的标签就是答案。

为什么用距离平方而不是距离本身?因为 $\sqrt{x}$ 在 $x \ge 0$ 上是单调递增函数,不影响大小关系,还省了开方的精度和性能损耗。

复杂度

  • 时间:$O(m \cdot n + m \log m)$,计算 $m$ 个样本的距离各需 $O(n)$,排序需 $O(m \log m)$
  • 空间:$O(m + n)$(不计输入缓冲),存储测试样本特征以及各样本的距离和标签

代码

import sys

def main():
    data = sys.stdin.read().split()
    idx = 0
    k = int(data[idx]); idx += 1
    m = int(data[idx]); idx += 1
    n = int(data[idx]); idx += 1
    s = int(data[idx]); idx += 1

    test = []
    for i in range(n):
        test.append(float(data[idx])); idx += 1

    samples = []
    for i in range(m):
        feats = []
        for j in range(n):
            feats.append(float(data[idx])); idx += 1
        label = int(float(data[idx])); idx += 1
        dist = sum((test[j] - feats[j]) ** 2 for j in range(n))
        samples.append((dist, i, label))

    samples.sort(key=lambda x: (x[0], x[1]))
    neighbors = samples[:k]

    from collections import Counter
    counts = Counter()
    for dist, i, label in neighbors:
        counts[label] += 1

    max_count = max(counts.values())
    tied_labels = {l for l, c in counts.items() if c == max_count}

    if len(tied_labels) == 1:
        print(tied_labels.pop(), max_count)
    else:
        for dist, i, label in neighbors:
            if label in tied_labels:
                print(label, max_count)
                break

main()
#include <bits/stdc++.h>
using namespace std;
// Solve one KNN query read from stdin and print "<label> <count>".
// Input: k m n s, then n test features, then m rows of n features followed
// by a label (s is read but not used). Training rows are ranked by squared
// Euclidean distance to the test sample, with equal distances broken by
// input order. The k nearest rows vote. If several labels tie on vote count,
// the winner is the tied label of the nearest neighbor.
int main(){
    int k, m, n, s;
    scanf("%d%d%d%d", &k, &m, &n, &s);
    vector<double> test(n);
    for(int i = 0; i < n; i++) scanf("%lf", &test[i]);
    // (squared distance, input index, label): the unique index makes the
    // tuple order "nearest first, then earliest in the input".
    vector<tuple<double,int,int>> samples;
    vector<double> f(n);  // feature buffer reused for every training row
    for(int i = 0; i < m; i++){
        for(int j = 0; j < n; j++) scanf("%lf", &f[j]);
        double lb; scanf("%lf", &lb);
        int label = (int)lb;  // label may be written as a float; truncate
        // sqrt is monotonic, so the squared distance ranks identically.
        double dist = 0;
        for(int j = 0; j < n; j++) dist += (test[j]-f[j])*(test[j]-f[j]);
        samples.emplace_back(dist, i, label);
    }
    sort(samples.begin(), samples.end());
    // Clamp k to the sample count: reading samples[i] past the end is UB.
    int kk = max(0, min(k, (int)samples.size()));
    if(kk == 0) return 0;  // nothing to vote on
    map<int,int> cnt;
    for(int i = 0; i < kk; i++) cnt[get<2>(samples[i])]++;
    int mx = 0;
    for(auto& p : cnt) mx = max(mx, p.second);
    // Samples are nearest-first, so the first neighbor whose label reached
    // mx is the nearest among the tied labels (or the unique winner).
    for(int i = 0; i < kk; i++){
        int label = get<2>(samples[i]);
        if(cnt[label] == mx){
            printf("%d %d\n", label, mx);
            break;
        }
    }
    return 0;
}
import java.util.*;
/**
 * Solves one KNN query read from stdin and prints "&lt;label&gt; &lt;count&gt;".
 *
 * <p>Input: k m n s, then n test features, then m rows of n features followed
 * by a label (s is read but not used). Training rows are ranked by squared
 * Euclidean distance to the test sample, with equal distances broken by input
 * order. The k nearest rows vote. If several labels tie on vote count, the
 * winner is the tied label of the nearest neighbor.
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        // Scanner parses numbers using the default locale; pin it so "1.5" is
        // read the same way on every JVM (not rejected where "," is decimal).
        sc.useLocale(Locale.ROOT);
        int k = sc.nextInt(), m = sc.nextInt(), n = sc.nextInt(), s = sc.nextInt();
        double[] test = new double[n];
        for (int i = 0; i < n; i++) test[i] = sc.nextDouble();
        double[] dists = new double[m];
        int[] labels = new int[m];
        Integer[] idx = new Integer[m];
        double[] f = new double[n]; // feature buffer reused for every row
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) f[j] = sc.nextDouble();
            labels[i] = (int) sc.nextDouble(); // label may be a float; truncate
            // sqrt is monotonic, so the squared distance ranks identically.
            double d = 0;
            for (int j = 0; j < n; j++) d += (test[j] - f[j]) * (test[j] - f[j]);
            dists[i] = d;
            idx[i] = i;
        }
        // Nearest first; equal distances fall back to input order.
        Arrays.sort(idx, (a, b) -> {
            int c = Double.compare(dists[a], dists[b]);
            return c != 0 ? c : Integer.compare(a, b);
        });
        // Clamp k: k > m would index past idx[], and k <= 0 would make
        // Collections.max below throw on an empty map.
        int kk = Math.max(0, Math.min(k, m));
        if (kk == 0) return;
        Map<Integer, Integer> cnt = new HashMap<>();
        for (int i = 0; i < kk; i++) cnt.merge(labels[idx[i]], 1, Integer::sum);
        int mx = Collections.max(cnt.values());
        // idx is nearest-first, so the first neighbor whose label reached mx
        // is the nearest among the tied labels (or the unique winner).
        for (int i = 0; i < kk; i++) {
            int label = labels[idx[i]];
            if (cnt.get(label) == mx) {
                System.out.println(label + " " + mx);
                break;
            }
        }
    }
}
// Solve one KNN query read from stdin and print "<label> <count>".
// Input: k m n s, then n test features, then m rows of n features followed
// by a label (s is read but not used). Training rows are ranked by squared
// Euclidean distance to the test sample, with equal distances broken by
// input order. The k nearest rows vote. If several labels tie on vote count,
// the winner is the tied label of the nearest neighbor.
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', l => lines.push(l.trim()));
rl.on('close', () => {
    // Treat the input as one token stream, like the other implementations,
    // so it does not matter how the numbers are split across lines.
    const tokens = lines.join(' ').split(/\s+/).filter(t => t.length > 0).map(Number);
    let pos = 0;
    const next = () => tokens[pos++];
    const k = next(), m = next(), n = next();
    next(); // s: read to stay in step with the input; not used here
    const test = [];
    for (let j = 0; j < n; j++) test.push(next());
    const samples = [];
    for (let i = 0; i < m; i++) {
        // sqrt is monotonic, so the squared distance ranks identically.
        let dist = 0;
        for (let j = 0; j < n; j++) dist += (test[j] - next()) ** 2;
        // Truncate float-written labels, matching the other implementations.
        const label = Math.trunc(next());
        samples.push([dist, i, label]);
    }
    // Nearest first; equal distances fall back to input order.
    samples.sort((a, b) => a[0] - b[0] || a[1] - b[1]);
    // Clamp k: indexing past the end would dereference undefined.
    const kk = Math.max(0, Math.min(k, samples.length));
    const cnt = new Map();
    for (let i = 0; i < kk; i++) {
        const lb = samples[i][2];
        cnt.set(lb, (cnt.get(lb) || 0) + 1);
    }
    if (cnt.size === 0) return; // nothing to vote on
    const mx = Math.max(...cnt.values());
    // samples is nearest-first, so the first neighbor whose label reached mx
    // is the nearest among the tied labels (or the unique winner).
    for (let i = 0; i < kk; i++) {
        const lb = samples[i][2];
        if (cnt.get(lb) === mx) {
            console.log(lb + " " + mx);
            break;
        }
    }
});