题目链接
题目描述
实现一个 K-近邻 (K-Nearest Neighbors, KNN) 分类器。给定一个待分类样本和多个带标签的训练样本,需要找出与待分类样本距离最近的 个训练样本,然后根据这
个“近邻”的标签,通过多数票原则确定待分类样本的类别。
具体规则如下:
- 计算待分类样本与每个训练样本的距离(为简化计算,可使用平方欧氏距离)。
- 将所有训练样本按距离升序排序。
- 选取前
个作为近邻。
- 统计这
个近邻中各个类别标签出现的次数。
- 出现次数最多的标签即为预测结果。
- 如果最高次数出现并列情况,则在这些并列的标签中,选择其对应近邻中距离最近的那个样本的标签作为最终结果。
解题思路
这是一个对 K-近邻算法的直接实现题,核心在于距离计算、排序和带特殊规则的投票。
算法步骤可以分解如下:
-
数据结构设计: 为了方便处理,我们可以定义一个结构体或类来存储每个训练样本的信息,至少包含:
- 原始的特征向量。
- 类别标签。
- 该样本与待测样本之间的距离(计算后填入)。
-
距离计算: 遍历所有
个训练样本。对于每个训练样本,计算它与待测样本之间的平方欧氏距离。 对于两个
维的点
和
,它们之间的平方欧氏距离为:
将计算出的距离存入对应样本的结构体中。
-
排序: 将所有训练样本(现在已包含距离信息)存储在一个列表中,并根据距离字段进行升序排序。
-
选取近邻并投票:
- 从排序后的列表中选取前
个样本,作为最近的邻居。
- 使用一个哈希表(
map
)来统计这个邻居中,每个类别标签出现的次数(频率)。
- 从排序后的列表中选取前
-
确定最终类别:
- 遍历频率哈希表,找出最高的出现次数
max_freq
。 - 找出所有出现次数等于
max_freq
的标签,将它们存入一个“并列候选”集合。 - 如果“并列候选”集合中只有一个标签,那么它就是最终的预测结果。
- 如果多于一个标签,则需要进行并列处理:回头查看原始的、按距离排好序的前
个邻居列表,从第一个(距离最近的)开始遍历,遇到的第一个其标签在“并列候选”集合中的邻居,它的标签就是最终的预测结果。
- 遍历频率哈希表,找出最高的出现次数
-
输出: 输出最终确定的类别标签,以及该标签在
个近邻中的出现次数(即
max_freq
)。
这个流程完整地实现了题目要求,特别是处理了并列情况的特殊规则。
代码
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#include <map>
#include <iomanip>
using namespace std;
struct Sample {
vector<double> features;
int label;
double distance;
};
// 计算平方欧氏距离
double squared_euclidean_distance(const vector<double>& a, const vector<double>& b) {
double dist = 0.0;
for (size_t i = 0; i < a.size(); ++i) {
dist += (a[i] - b[i]) * (a[i] - b[i]);
}
return dist;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int k, m, n, s;
cin >> k >> m >> n >> s;
vector<double> test_sample(n);
for (int i = 0; i < n; ++i) {
cin >> test_sample[i];
}
vector<Sample> training_samples(m);
for (int i = 0; i < m; ++i) {
training_samples[i].features.resize(n);
for (int j = 0; j < n; ++j) {
cin >> training_samples[i].features[j];
}
double label_float;
cin >> label_float;
training_samples[i].label = static_cast<int>(label_float);
training_samples[i].distance = squared_euclidean_distance(test_sample, training_samples[i].features);
}
// 按距离排序
stable_sort(training_samples.begin(), training_samples.end(), [](const Sample& a, const Sample& b) {
return a.distance < b.distance;
});
// 统计前k个邻居的标签频率
map<int, int> label_counts;
for (int i = 0; i < k; ++i) {
label_counts[training_samples[i].label]++;
}
int max_freq = 0;
for (auto const& [label, freq] : label_counts) {
if (freq > max_freq) {
max_freq = freq;
}
}
// 找出所有最高频率的标签
vector<int> tied_labels;
for (auto const& [label, freq] : label_counts) {
if (freq == max_freq) {
tied_labels.push_back(label);
}
}
int predicted_label;
if (tied_labels.size() == 1) {
predicted_label = tied_labels[0];
} else {
// 并列处理:按距离顺序找第一个
for (int i = 0; i < k; ++i) {
bool is_tied = false;
for (int label : tied_labels) {
if (training_samples[i].label == label) {
is_tied = true;
break;
}
}
if (is_tied) {
predicted_label = training_samples[i].label;
break;
}
}
}
cout << predicted_label << " " << max_freq << "\n";
return 0;
}
import java.util.*;
class Sample {
double[] features;
int label;
double distance;
public Sample(int n) {
this.features = new double[n];
}
}
public class Main {
// 计算平方欧氏距离
private static double squaredEuclideanDistance(double[] a, double[] b) {
double dist = 0.0;
for (int i = 0; i < a.length; i++) {
dist += (a[i] - b[i]) * (a[i] - b[i]);
}
return dist;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int k = sc.nextInt();
int m = sc.nextInt();
int n = sc.nextInt();
int s = sc.nextInt();
double[] testSample = new double[n];
for (int i = 0; i < n; i++) {
testSample[i] = sc.nextDouble();
}
List<Sample> trainingSamples = new ArrayList<>();
for (int i = 0; i < m; i++) {
Sample sample = new Sample(n);
for (int j = 0; j < n; j++) {
sample.features[j] = sc.nextDouble();
}
sample.label = (int) sc.nextDouble();
sample.distance = squaredEuclideanDistance(testSample, sample.features);
trainingSamples.add(sample);
}
// 按距离排序
trainingSamples.sort(Comparator.comparingDouble(smp -> smp.distance));
// 统计前k个邻居的标签频率
Map<Integer, Integer> labelCounts = new HashMap<>();
for (int i = 0; i < k; i++) {
int label = trainingSamples.get(i).label;
labelCounts.put(label, labelCounts.getOrDefault(label, 0) + 1);
}
int maxFreq = 0;
for (int freq : labelCounts.values()) {
if (freq > maxFreq) {
maxFreq = freq;
}
}
// 找出所有最高频率的标签
List<Integer> tiedLabels = new ArrayList<>();
for (Map.Entry<Integer, Integer> entry : labelCounts.entrySet()) {
if (entry.getValue() == maxFreq) {
tiedLabels.add(entry.getKey());
}
}
int predictedLabel;
if (tiedLabels.size() == 1) {
predictedLabel = tiedLabels.get(0);
} else {
// 并列处理
predictedLabel = -1; // 初始化
for (int i = 0; i < k; i++) {
if (tiedLabels.contains(trainingSamples.get(i).label)) {
predictedLabel = trainingSamples.get(i).label;
break;
}
}
}
System.out.println(predictedLabel + " " + maxFreq);
}
}
import sys
def main():
# 读取第一行输入
line1 = input().split()
k, m, n, s = map(int, line1)
# 读取待测样本
test_sample = list(map(float, input().split()))
training_samples = []
for _ in range(m):
line = list(map(float, input().split()))
features = line[:-1]
label = int(line[-1])
# 计算平方欧氏距离
distance = sum((test_sample[i] - features[i]) ** 2 for i in range(n))
training_samples.append({'features': features, 'label': label, 'distance': distance})
# 按距离排序
training_samples.sort(key=lambda x: x['distance'])
# 选取k个近邻
k_neighbors = training_samples[:k]
# 统计标签频率
label_counts = {}
for neighbor in k_neighbors:
label = neighbor['label']
label_counts[label] = label_counts.get(label, 0) + 1
if not label_counts:
# 如果k=0或没有邻居,虽然题目约束不会发生,但做好保护
return
# 找出最高频率
max_freq = 0
for freq in label_counts.values():
if freq > max_freq:
max_freq = freq
# 找出所有最高频率的标签
tied_labels = [label for label, freq in label_counts.items() if freq == max_freq]
predicted_label = -1
if len(tied_labels) == 1:
predicted_label = tied_labels[0]
else:
# 并列处理
for neighbor in k_neighbors:
if neighbor['label'] in tied_labels:
predicted_label = neighbor['label']
break
print(f"{predicted_label} {max_freq}")
if __name__ == "__main__":
main()
算法及复杂度
- 算法:K-近邻 (K-Nearest Neighbors, KNN)
- 时间复杂度:
,其中
是训练样本数,
是特征维度。
- 计算所有样本与待测点的距离需要
。
- 对
个样本按距离排序需要
。
- 统计和并列处理需要
,其中
。
- 因此,总时间复杂度由距离计算和排序主导。
- 计算所有样本与待测点的距离需要
- 空间复杂度:
,主要用于存储所有训练样本的特征数据。