题目链接

标签在前K个近邻中的出现次数

题目描述

实现一个 K-近邻 (K-Nearest Neighbors, KNN) 分类器。给定一个待分类样本和多个带标签的训练样本,需要找出与待分类样本距离最近的 个训练样本,然后根据这 个“近邻”的标签,通过多数票原则确定待分类样本的类别。

具体规则如下:

  1. 计算待分类样本与每个训练样本的距离(为简化计算,可使用平方欧氏距离)。
  2. 将所有训练样本按距离升序排序。
  3. 选取前 个作为近邻。
  4. 统计这 个近邻中各个类别标签出现的次数。
  5. 出现次数最多的标签即为预测结果。
  6. 如果最高次数出现并列情况,则在这些并列的标签中,选择其对应近邻中距离最近的那个样本的标签作为最终结果。

解题思路

这是一个对 K-近邻算法的直接实现题,核心在于距离计算、排序和带特殊规则的投票。

算法步骤可以分解如下:

  1. 数据结构设计: 为了方便处理,我们可以定义一个结构体或类来存储每个训练样本的信息,至少包含:

    • 原始的特征向量。
    • 类别标签。
    • 该样本与待测样本之间的距离(计算后填入)。
  2. 距离计算: 遍历所有 个训练样本。对于每个训练样本,计算它与待测样本之间的平方欧氏距离。 对于两个 维的点 ,它们之间的平方欧氏距离为: 将计算出的距离存入对应样本的结构体中。

  3. 排序: 将所有训练样本(现在已包含距离信息)存储在一个列表中,并根据距离字段进行升序排序。

  4. 选取近邻并投票

    • 从排序后的列表中选取前 个样本,作为最近的邻居。
    • 使用一个哈希表(map)来统计这 个邻居中,每个类别标签出现的次数(频率)。
  5. 确定最终类别

    • 遍历频率哈希表,找出最高的出现次数 max_freq
    • 找出所有出现次数等于 max_freq 的标签,将它们存入一个“并列候选”集合。
    • 如果“并列候选”集合中只有一个标签,那么它就是最终的预测结果。
    • 如果多于一个标签,则需要进行并列处理:回头查看原始的、按距离排好序的前 个邻居列表,从第一个(距离最近的)开始遍历,遇到的第一个其标签在“并列候选”集合中的邻居,它的标签就是最终的预测结果。
  6. 输出: 输出最终确定的类别标签,以及该标签在 个近邻中的出现次数(即 max_freq)。

这个流程完整地实现了题目要求,特别是处理了并列情况的特殊规则。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#include <map>
#include <iomanip>

using namespace std;

struct Sample {
    vector<double> features;
    int label;
    double distance;
};

// 计算平方欧氏距离
double squared_euclidean_distance(const vector<double>& a, const vector<double>& b) {
    double dist = 0.0;
    for (size_t i = 0; i < a.size(); ++i) {
        dist += (a[i] - b[i]) * (a[i] - b[i]);
    }
    return dist;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int k, m, n, s;
    cin >> k >> m >> n >> s;

    vector<double> test_sample(n);
    for (int i = 0; i < n; ++i) {
        cin >> test_sample[i];
    }

    vector<Sample> training_samples(m);
    for (int i = 0; i < m; ++i) {
        training_samples[i].features.resize(n);
        for (int j = 0; j < n; ++j) {
            cin >> training_samples[i].features[j];
        }
        double label_float;
        cin >> label_float;
        training_samples[i].label = static_cast<int>(label_float);
        training_samples[i].distance = squared_euclidean_distance(test_sample, training_samples[i].features);
    }

    // 按距离排序
    stable_sort(training_samples.begin(), training_samples.end(), [](const Sample& a, const Sample& b) {
        return a.distance < b.distance;
    });

    // 统计前k个邻居的标签频率
    map<int, int> label_counts;
    for (int i = 0; i < k; ++i) {
        label_counts[training_samples[i].label]++;
    }

    int max_freq = 0;
    for (auto const& [label, freq] : label_counts) {
        if (freq > max_freq) {
            max_freq = freq;
        }
    }

    // 找出所有最高频率的标签
    vector<int> tied_labels;
    for (auto const& [label, freq] : label_counts) {
        if (freq == max_freq) {
            tied_labels.push_back(label);
        }
    }

    int predicted_label;
    if (tied_labels.size() == 1) {
        predicted_label = tied_labels[0];
    } else {
        // 并列处理:按距离顺序找第一个
        for (int i = 0; i < k; ++i) {
            bool is_tied = false;
            for (int label : tied_labels) {
                if (training_samples[i].label == label) {
                    is_tied = true;
                    break;
                }
            }
            if (is_tied) {
                predicted_label = training_samples[i].label;
                break;
            }
        }
    }

    cout << predicted_label << " " << max_freq << "\n";

    return 0;
}
import java.util.*;

class Sample {
    double[] features;
    int label;
    double distance;

    public Sample(int n) {
        this.features = new double[n];
    }
}

public class Main {
    // 计算平方欧氏距离
    private static double squaredEuclideanDistance(double[] a, double[] b) {
        double dist = 0.0;
        for (int i = 0; i < a.length; i++) {
            dist += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return dist;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int k = sc.nextInt();
        int m = sc.nextInt();
        int n = sc.nextInt();
        int s = sc.nextInt();

        double[] testSample = new double[n];
        for (int i = 0; i < n; i++) {
            testSample[i] = sc.nextDouble();
        }

        List<Sample> trainingSamples = new ArrayList<>();
        for (int i = 0; i < m; i++) {
            Sample sample = new Sample(n);
            for (int j = 0; j < n; j++) {
                sample.features[j] = sc.nextDouble();
            }
            sample.label = (int) sc.nextDouble();
            sample.distance = squaredEuclideanDistance(testSample, sample.features);
            trainingSamples.add(sample);
        }

        // 按距离排序
        trainingSamples.sort(Comparator.comparingDouble(smp -> smp.distance));

        // 统计前k个邻居的标签频率
        Map<Integer, Integer> labelCounts = new HashMap<>();
        for (int i = 0; i < k; i++) {
            int label = trainingSamples.get(i).label;
            labelCounts.put(label, labelCounts.getOrDefault(label, 0) + 1);
        }

        int maxFreq = 0;
        for (int freq : labelCounts.values()) {
            if (freq > maxFreq) {
                maxFreq = freq;
            }
        }

        // 找出所有最高频率的标签
        List<Integer> tiedLabels = new ArrayList<>();
        for (Map.Entry<Integer, Integer> entry : labelCounts.entrySet()) {
            if (entry.getValue() == maxFreq) {
                tiedLabels.add(entry.getKey());
            }
        }

        int predictedLabel;
        if (tiedLabels.size() == 1) {
            predictedLabel = tiedLabels.get(0);
        } else {
            // 并列处理
            predictedLabel = -1; // 初始化
            for (int i = 0; i < k; i++) {
                if (tiedLabels.contains(trainingSamples.get(i).label)) {
                    predictedLabel = trainingSamples.get(i).label;
                    break;
                }
            }
        }

        System.out.println(predictedLabel + " " + maxFreq);
    }
}
import sys

def main():
    # 读取第一行输入
    line1 = input().split()
    k, m, n, s = map(int, line1)

    # 读取待测样本
    test_sample = list(map(float, input().split()))

    training_samples = []
    for _ in range(m):
        line = list(map(float, input().split()))
        features = line[:-1]
        label = int(line[-1])
        
        # 计算平方欧氏距离
        distance = sum((test_sample[i] - features[i]) ** 2 for i in range(n))
        training_samples.append({'features': features, 'label': label, 'distance': distance})

    # 按距离排序
    training_samples.sort(key=lambda x: x['distance'])

    # 选取k个近邻
    k_neighbors = training_samples[:k]

    # 统计标签频率
    label_counts = {}
    for neighbor in k_neighbors:
        label = neighbor['label']
        label_counts[label] = label_counts.get(label, 0) + 1

    if not label_counts:
        # 如果k=0或没有邻居,虽然题目约束不会发生,但做好保护
        return

    # 找出最高频率
    max_freq = 0
    for freq in label_counts.values():
        if freq > max_freq:
            max_freq = freq
    
    # 找出所有最高频率的标签
    tied_labels = [label for label, freq in label_counts.items() if freq == max_freq]

    predicted_label = -1
    if len(tied_labels) == 1:
        predicted_label = tied_labels[0]
    else:
        # 并列处理
        for neighbor in k_neighbors:
            if neighbor['label'] in tied_labels:
                predicted_label = neighbor['label']
                break
    
    print(f"{predicted_label} {max_freq}")

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:K-近邻 (K-Nearest Neighbors, KNN)
  • 时间复杂度:,其中 是训练样本数, 是特征维度。
    • 计算所有样本与待测点的距离需要
    • 个样本按距离排序需要
    • 统计和并列处理需要 ,其中
    • 因此,总时间复杂度由距离计算和排序主导。
  • 空间复杂度:,主要用于存储所有训练样本的特征数据。