题目链接

均衡版 KMeans 分群与新用户归类

题目描述

某电商平台需要将 位老客户根据其 维非负整数特征,划分为 个群组。为了确保资源分配公平,群组容量需要严格均衡。你需要实现一个该平台定制的 KMeans 变体算法,并用最终形成的群组中心(centroids)对一个新客户进行分类。

算法规则详情:

  1. 初始中心:取输入数据的前 位客户的特征向量作为初始的 个中心点。

  2. 群组容量:每个群组的容量是固定的。前 个群组的容量为 ,其余群组的容量为

  3. 分配阶段

    • 每一轮分配都严格按照客户的输入顺序(从第 1 位到第 位)进行。
    • 对每个客户,计算其与所有中心的欧氏距离(为简化计算,可直接使用距离的平方和进行比较)。
    • 从所有尚未满员的群组中,为该客户选择一个距离最近的中心。
    • 平局打破:如果到多个中心的距离相等,则选择中心编号更小的那个。
  4. 更新阶段

    • 一轮分配(所有 个客户都被分配完毕)后,需要更新每个中心的坐标。
    • 新的中心坐标是该群组内所有成员特征向量的逐维平均值,并且对平均值向下取整 (floor)
  5. 收敛条件

    • 如果在一轮迭代之后,所有客户的分配结果以及所有中心的坐标与上一轮完全相同,则算法收敛,迭代停止。
  6. 输出与新客户分类

    • 算法收敛后,将最终的 个中心点按字典序(优先比较第一维,其次第二维,以此类推)进行升序排序,并输出。
    • 接着,给定一个新客户的特征,计算其到所有已排序的最终中心的距离。
    • 将其归类到距离最近的中心。若距离有并列,则选择在排序后列表中字典序最小(即索引靠前)的中心。
    • 最后,输出新客户所属中心在排序后列表中的序号(从 1 开始)

输入描述:

  • 第一行:
  • 接下来 行: 位老客户的特征。
  • 最后一行:新客户的特征。

解题思路

本题是对 KMeans 算法的深度定制,解题关键在于精确实现其独特的规则。

  1. 初始化

    • 读取所有客户数据和新客户数据。
    • 根据规则 1,将前 个客户的特征复制为初始中心点。
    • 根据规则 2,计算并存储每个群组的固定容量。
  2. 迭代聚类

    • 使用一个主循环来不断进行“分配-更新”的迭代过程,直到满足收敛条件。
    • 在每轮迭代开始前,需要深拷贝当前的中心点和客户分配情况,用于迭代结束后的收敛性检查。
  3. 分配阶段(核心)

    • 清空上一轮的群组成员信息,并准备一个数组记录各群组当前已分配的人数。
    • 严格按照从 的顺序遍历所有客户。
    • 对每个客户,遍历所有 个中心点:
      • 筛选出所有容量未满的中心。
      • 在这些候选中心里,找到距离该客户最近的。使用平方欧氏距离避免开方运算。
      • 由于中心点是按编号 顺序遍历的,因此在距离相同时,最先找到的那个自然就是编号最小的,这巧妙地满足了平局规则。
    • 将客户分配给找到的最佳群组,并更新该群组的当前人数。
  4. 更新阶段

    • 为每个群组计算新的中心点。
    • 遍历群组内的所有客户,逐维累加他们的特征值。
    • 将累加和除以群组内的客户总数得到平均值,然后对每一维都进行向下取整,得到新的中心点坐标。
  5. 收敛判断

    • 比较本轮迭代产生的新中心点、新分配情况与迭代前保存的状态。只有当两者都完全相同时,才设置收敛标志,退出主循环。
  6. 输出与分类

    • 迭代结束后,对得到的最终 个中心点进行字典序排序。
    • 输出排序后的中心点。
    • 计算新客户到每一个已排序中心点的距离,找到距离最近的中心。由于中心已排序,遍历时第一个遇到的最近中心即满足“字典序最小”的平局规则。
    • 输出该中心的 1-based 索引。

数据结构注意

  • 需要使用 long 类型来计算平方距离和特征累加和,以防止整数溢出。
  • 在检查收敛时,务必使用深拷贝(deep copy)来比较前后状态,避免因引用相同对象而导致判断失效。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <algorithm>

using namespace std;

long long squared_distance(const vector<long long>& p1, const vector<long long>& p2) {
    long long dist = 0;
    for (size_t i = 0; i < p1.size(); ++i) {
        dist += (p1[i] - p2[i]) * (p1[i] - p2[i]);
    }
    return dist;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m, k;
    cin >> n >> m >> k;

    vector<vector<long long>> customers(n, vector<long long>(m));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            cin >> customers[i][j];
        }
    }

    vector<long long> new_customer(m);
    for (int i = 0; i < m; ++i) {
        cin >> new_customer[i];
    }

    vector<vector<long long>> centroids(k, vector<long long>(m));
    for (int i = 0; i < k; ++i) {
        centroids[i] = customers[i];
    }

    vector<int> capacities(k);
    int base_capacity = n / k;
    int remainder = n % k;
    for (int i = 0; i < k; ++i) {
        capacities[i] = base_capacity + (i < remainder ? 1 : 0);
    }

    vector<int> assignments(n, -1);
    bool converged = false;

    while (!converged) {
        vector<vector<long long>> prev_centroids = centroids;
        vector<int> prev_assignments = assignments;

        vector<vector<int>> clusters(k);
        vector<int> current_sizes(k, 0);

        for (int i = 0; i < n; ++i) {
            long long min_dist = -1;
            int best_cluster = -1;
            for (int j = 0; j < k; ++j) {
                if (current_sizes[j] < capacities[j]) {
                    long long dist = squared_distance(customers[i], centroids[j]);
                    if (best_cluster == -1 || dist < min_dist) {
                        min_dist = dist;
                        best_cluster = j;
                    }
                }
            }
            clusters[best_cluster].push_back(i);
            current_sizes[best_cluster]++;
            assignments[i] = best_cluster;
        }

        for (int i = 0; i < k; ++i) {
            if (!clusters[i].empty()) {
                vector<long long> sum(m, 0);
                for (int customer_idx : clusters[i]) {
                    for (int j = 0; j < m; ++j) {
                        sum[j] += customers[customer_idx][j];
                    }
                }
                for (int j = 0; j < m; ++j) {
                    centroids[i][j] = floor((double)sum[j] / clusters[i].size());
                }
            }
        }
        
        if (centroids == prev_centroids && assignments == prev_assignments) {
            converged = true;
        }
    }

    sort(centroids.begin(), centroids.end());

    for (int i = 0; i < k; ++i) {
        for (int j = 0; j < m; ++j) {
            cout << centroids[i][j] << (j == m - 1 ? "" : " ");
        }
        cout << endl;
    }

    long long min_dist = -1;
    int best_cluster_idx = -1;
    for (int i = 0; i < k; ++i) {
        long long dist = squared_distance(new_customer, centroids[i]);
        if (best_cluster_idx == -1 || dist < min_dist) {
            min_dist = dist;
            best_cluster_idx = i;
        }
    }

    cout << best_cluster_idx + 1 << endl;

    return 0;
}
import java.util.*;
import java.lang.Math;

public class Main {

    private static long squaredDistance(long[] p1, long[] p2) {
        long dist = 0;
        for (int i = 0; i < p1.length; i++) {
            dist += (p1[i] - p2[i]) * (p1[i] - p2[i]);
        }
        return dist;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        int k = sc.nextInt();

        long[][] customers = new long[n][m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                customers[i][j] = sc.nextLong();
            }
        }

        long[] newCustomer = new long[m];
        for (int i = 0; i < m; i++) {
            newCustomer[i] = sc.nextLong();
        }

        long[][] centroids = new long[k][m];
        for (int i = 0; i < k; i++) {
            centroids[i] = Arrays.copyOf(customers[i], m);
        }

        int[] capacities = new int[k];
        int baseCapacity = n / k;
        int remainder = n % k;
        for (int i = 0; i < k; i++) {
            capacities[i] = baseCapacity + (i < remainder ? 1 : 0);
        }

        int[] assignments = new int[n];
        Arrays.fill(assignments, -1);
        boolean converged = false;

        while (!converged) {
            long[][] prevCentroids = new long[k][];
            for(int i=0; i<k; i++) prevCentroids[i] = Arrays.copyOf(centroids[i], m);
            int[] prevAssignments = Arrays.copyOf(assignments, n);

            List<List<Integer>> clusters = new ArrayList<>();
            for (int i = 0; i < k; i++) clusters.add(new ArrayList<>());
            int[] currentSizes = new int[k];

            for (int i = 0; i < n; i++) {
                long minDis = -1;
                int bestCluster = -1;
                for (int j = 0; j < k; j++) {
                    if (currentSizes[j] < capacities[j]) {
                        long dist = squaredDistance(customers[i], centroids[j]);
                        if (bestCluster == -1 || dist < minDis) {
                            minDis = dist;
                            bestCluster = j;
                        }
                    }
                }
                clusters.get(bestCluster).add(i);
                currentSizes[bestCluster]++;
                assignments[i] = bestCluster;
            }
            
            for (int i = 0; i < k; i++) {
                if (!clusters.get(i).isEmpty()) {
                    long[] sum = new long[m];
                    for (int customerIdx : clusters.get(i)) {
                        for (int j = 0; j < m; j++) {
                            sum[j] += customers[customerIdx][j];
                        }
                    }
                    for (int j = 0; j < m; j++) {
                        centroids[i][j] = (long) Math.floor((double) sum[j] / clusters.get(i).size());
                    }
                }
            }

            if (Arrays.deepEquals(centroids, prevCentroids) && Arrays.equals(assignments, prevAssignments)) {
                converged = true;
            }
        }

        Arrays.sort(centroids, (a, b) -> {
            for (int i = 0; i < m; i++) {
                if (a[i] != b[i]) {
                    return Long.compare(a[i], b[i]);
                }
            }
            return 0;
        });

        for (int i = 0; i < k; i++) {
            for (int j = 0; j < m; j++) {
                System.out.print(centroids[i][j] + (j == m - 1 ? "" : " "));
            }
            System.out.println();
        }

        long minDis = -1;
        int bestClusterIdx = -1;
        for (int i = 0; i < k; i++) {
            long dist = squaredDistance(newCustomer, centroids[i]);
            if (bestClusterIdx == -1 || dist < minDis) {
                minDis = dist;
                bestClusterIdx = i;
            }
        }
        System.out.println(bestClusterIdx + 1);
    }
}
import math

def squared_distance(p1, p2):
    return sum((x - y) ** 2 for x, y in zip(p1, p2))

def solve():
    n, m, k = map(int, input().split())
    customers = [list(map(int, input().split())) for _ in range(n)]
    new_customer = list(map(int, input().split()))

    centroids = [customers[i] for i in range(k)]

    capacities = [n // k] * k
    for i in range(n % k):
        capacities[i] += 1
    
    assignments = [-1] * n
    converged = False

    while not converged:
        prev_centroids = [list(c) for c in centroids]
        prev_assignments = list(assignments)

        clusters = [[] for _ in range(k)]
        current_sizes = [0] * k

        for i in range(n):
            min_dist = -1
            best_cluster = -1
            for j in range(k):
                if current_sizes[j] < capacities[j]:
                    dist = squared_distance(customers[i], centroids[j])
                    if best_cluster == -1 or dist < min_dist:
                        min_dist = dist
                        best_cluster = j
            
            clusters[best_cluster].append(i)
            current_sizes[best_cluster] += 1
            assignments[i] = best_cluster

        for i in range(k):
            if clusters[i]:
                sum_features = [0] * m
                for customer_idx in clusters[i]:
                    for j in range(m):
                        sum_features[j] += customers[customer_idx][j]
                
                centroids[i] = [math.floor(s / len(clusters[i])) for s in sum_features]

        if centroids == prev_centroids and assignments == prev_assignments:
            converged = True

    centroids.sort()

    for center in centroids:
        print(*center)

    min_dist = -1
    best_cluster_idx = -1
    for i in range(k):
        dist = squared_distance(new_customer, centroids[i])
        if best_cluster_idx == -1 or dist < min_dist:
            min_dist = dist
            best_cluster_idx = i

    print(best_cluster_idx + 1)

solve()

算法及复杂度

  • 算法:本题实现了一个高度定制化的 KMeans 聚类算法。其核心是一个迭代过程,每轮迭代包括“按顺序和容量限制分配”和“更新中心点”两个阶段,直至收敛。
  • 时间复杂度,其中 是收敛所需的迭代次数。
    • 在每次迭代中,分配阶段需要为 个客户中的每一个计算到 个中心的距离,每个距离计算耗时 ,总计
    • 更新阶段需要遍历所有 个客户一次来计算新的中心,总计
    • 因此,单次迭代的复杂度由分配阶段主导,为
    • 最终的中心排序和新客户分类的复杂度()远小于迭代部分,可以忽略不计。
  • 空间复杂度
    • 主要空间开销用于存储 个客户的特征数据 () 和 个中心的特征数据 ()。
    • 在迭代过程中还需要存储每个客户的分配结果和每个群组的成员列表,总空间为 ,这通常远小于特征数据的存储。