标签在前K个近邻中的出现次数
KNN(K近邻)分类器的核心流程是什么?算距离、排序、投票。这题就是把这个流程走一遍。
思路
拿到测试样本后,对每个训练样本算欧氏距离的平方(不用开根号,不影响排序),然后按距离从小到大排序,取前 K 个邻居统计标签出现次数。
唯一需要注意的是平票处理:如果有多个标签出现次数并列最高,选哪个?题目规定选在这些平票标签中,最近的那个邻居所属的标签。
具体做法:排好序以后,从近到远扫描 K 个邻居,第一个属于平票标签集合的邻居,它的标签就是答案。
为什么用距离平方而不是距离本身?因为 $\sqrt{x}$ 在 $x \ge 0$ 上是单调递增函数,比较距离的平方与比较距离本身得到的大小关系完全一致,还省了开方的精度和性能损耗。
复杂度
- 时间:$O(mn + m\log m)$,计算 $m$ 个样本的距离各需 $O(n)$,排序 $O(m\log m)$
- 空间:$O(m)$,存储距离和标签
代码
import sys


def main():
    """Read one KNN query from stdin and print the winning label and its votes.

    Input is whitespace-separated tokens: four integers ``k m n s``, then
    ``n`` feature values of the test point, then ``m`` records of ``n``
    feature values followed by a label. The fourth header value is consumed
    but not used by the computation.

    Output is a single line ``"<label> <votes>"`` where ``votes`` is the
    highest label count among the ``k`` nearest samples. When several labels
    share that count, the label of the nearest neighbour among them wins.
    Nothing is printed when there are no neighbours (``k <= 0`` or ``m == 0``).
    """
    from collections import Counter

    tokens = iter(sys.stdin.read().split())
    k = int(next(tokens))
    m = int(next(tokens))
    n = int(next(tokens))
    next(tokens)  # fourth header value: read to stay aligned, never used

    test = [float(next(tokens)) for _ in range(n)]

    samples = []
    for i in range(m):
        feats = [float(next(tokens)) for _ in range(n)]
        # Labels may be written as floats (e.g. "1.0"), so parse via float.
        label = int(float(next(tokens)))
        # Squared distance is enough: sqrt is monotonic, so ordering is equal.
        dist = sum((t - f) ** 2 for t, f in zip(test, feats))
        samples.append((dist, i, label))

    # Tuples order by distance first, then by input index; the index is
    # unique, so equal distances resolve to the earlier sample.
    samples.sort()
    # Slicing tolerates k > m; a non-positive k yields no neighbours.
    neighbors = samples[:k] if k > 0 else []
    if not neighbors:
        return

    counts = Counter(label for _, _, label in neighbors)
    max_count = max(counts.values())
    # Scanning from nearest to farthest picks the nearest top-voted label,
    # which also covers the case of a single winner.
    for _, _, label in neighbors:
        if counts[label] == max_count:
            print(label, max_count)
            return


if __name__ == "__main__":
    main()
#include <bits/stdc++.h>
using namespace std;
// Reads one KNN query from stdin and prints "<label> <votes>".
//
// Input: four integers k, m, n, s; then n feature values of the test point;
// then m records of n feature values followed by a label (read as a double
// and truncated to int). The value s is read but not used below.
//
// Output: the label with the highest count among the k nearest samples and
// that count. When several labels share the highest count, the label of the
// nearest neighbour among them is printed. Nothing is printed when there are
// no neighbours.
int main() {
    int k, m, n, s;
    // Without a complete header the remaining reads would use
    // uninitialised values, so stop early.
    if (scanf("%d%d%d%d", &k, &m, &n, &s) != 4) return 0;

    std::vector<double> test(n);
    for (int i = 0; i < n; i++) scanf("%lf", &test[i]);

    // Each entry is (squared distance, input index, label).
    std::vector<std::tuple<double, int, int>> samples;
    for (int i = 0; i < m; i++) {
        std::vector<double> f(n);
        for (int j = 0; j < n; j++) scanf("%lf", &f[j]);
        double lb;
        scanf("%lf", &lb);
        int label = (int)lb;
        // Squared distance is enough: sqrt is monotonic, so ordering is equal.
        double dist = 0;
        for (int j = 0; j < n; j++) dist += (test[j] - f[j]) * (test[j] - f[j]);
        samples.emplace_back(dist, i, label);
    }

    // Tuple ordering compares distance first, then input index, so equal
    // distances resolve to the earlier sample.
    std::sort(samples.begin(), samples.end());

    // Never index past the samples actually read (k may exceed m).
    const int top = std::min(k, static_cast<int>(samples.size()));

    std::map<int, int> cnt;
    for (int i = 0; i < top; i++) cnt[std::get<2>(samples[i])]++;

    int mx = 0;
    for (const auto& p : cnt) mx = std::max(mx, p.second);

    // Scanning from nearest to farthest picks the nearest top-voted label.
    for (int i = 0; i < top; i++) {
        const int label = std::get<2>(samples[i]);
        if (cnt[label] == mx) {
            printf("%d %d\n", label, mx);
            break;
        }
    }
    return 0;
}
import java.util.*;
/**
 * KNN voting over whitespace-separated input read from standard input.
 *
 * <p>Input: four integers {@code k m n s}; then {@code n} feature values of
 * the test point; then {@code m} records of {@code n} feature values followed
 * by a label (read as a double and truncated to int). The value {@code s} is
 * read but not used.
 *
 * <p>Output: one line {@code "<label> <votes>"}, where {@code votes} is the
 * highest label count among the {@code k} nearest samples. When several
 * labels share that count, the label of the nearest neighbour among them is
 * printed. Nothing is printed when there are no neighbours.
 */
public class Main {
    public static void main(String[] args) {
        // Pin the locale so "0.5" parses regardless of the machine's
        // default decimal separator.
        Scanner sc = new Scanner(System.in).useLocale(Locale.US);
        int k = sc.nextInt(), m = sc.nextInt(), n = sc.nextInt(), s = sc.nextInt();

        double[] test = new double[n];
        for (int i = 0; i < n; i++) test[i] = sc.nextDouble();

        double[] dists = new double[m];
        int[] labels = new int[m];
        Integer[] idx = new Integer[m];
        for (int i = 0; i < m; i++) {
            double[] f = new double[n];
            for (int j = 0; j < n; j++) f[j] = sc.nextDouble();
            labels[i] = (int) sc.nextDouble();
            // Squared distance is enough: sqrt is monotonic, so ordering is equal.
            double d = 0;
            for (int j = 0; j < n; j++) d += (test[j] - f[j]) * (test[j] - f[j]);
            dists[i] = d;
            idx[i] = i;
        }

        // Order sample indices by distance, then by input position on ties.
        Arrays.sort(idx, (a, b) -> {
            int c = Double.compare(dists[a], dists[b]);
            return c != 0 ? c : Integer.compare(a, b);
        });

        // Never index past the samples actually read (k may exceed m).
        int top = Math.min(k, m);
        if (top <= 0) return;

        Map<Integer, Integer> cnt = new HashMap<>();
        for (int i = 0; i < top; i++) cnt.merge(labels[idx[i]], 1, Integer::sum);
        int mx = Collections.max(cnt.values());

        // Scanning from nearest to farthest picks the nearest top-voted label.
        for (int i = 0; i < top; i++) {
            int label = labels[idx[i]];
            if (cnt.get(label) == mx) {
                System.out.println(label + " " + mx);
                break;
            }
        }
    }
}
const readline = require('readline');
// Collects stdin line by line, then on end-of-input runs KNN voting and
// prints "<label> <votes>".
//
// Expected layout of the non-blank lines: line 0 is the header "k m n s",
// line 1 holds the n feature values of the test point, and each of the next
// m lines holds n feature values followed by a label. The header value s is
// not used.
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
// Blank lines carry no data; dropping them keeps the fixed offsets below valid.
rl.on('line', l => {
  const text = l.trim();
  if (text) lines.push(text);
});
rl.on('close', () => {
  // Nothing to classify without a header and a test point.
  if (lines.length < 2) return;

  const [k, m, n] = lines[0].split(/\s+/).map(Number);
  const test = lines[1].split(/\s+/).map(Number);

  // Each entry is [squared distance, input index, label].
  const samples = [];
  for (let i = 0; i < m; i++) {
    const parts = lines[i + 2].split(/\s+/).map(Number);
    const feats = parts.slice(0, n);
    const label = Math.round(parts[n]);
    // Squared distance is enough: sqrt is monotonic, so ordering is equal.
    let dist = 0;
    for (let j = 0; j < n; j++) dist += (test[j] - feats[j]) ** 2;
    samples.push([dist, i, label]);
  }

  // Nearest first; equal distances resolve to the earlier input line.
  samples.sort((a, b) => a[0] - b[0] || a[1] - b[1]);

  // Never index past the samples actually read (k may exceed m).
  const top = Math.min(k, samples.length);

  const cnt = {};
  for (let i = 0; i < top; i++) {
    const lb = samples[i][2];
    cnt[lb] = (cnt[lb] || 0) + 1;
  }
  const mx = Math.max(...Object.values(cnt));

  // Scanning from nearest to farthest picks the nearest top-voted label.
  for (let i = 0; i < top; i++) {
    const lb = samples[i][2];
    if (cnt[lb] === mx) {
      console.log(lb + " " + mx);
      break;
    }
  }
});

京公网安备 11010502036488号