题目链接

用户分群

题目描述

某电商平台希望根据用户的三个特征指标:月均消费金额($x$)、月均访问次数($y$)和归一化后的退货率($z$),对用户进行分群。你需要实现 KMeans 聚类算法。给定 $K$ 个初始聚类中心和 $M$ 个数据点,按照以下流程迭代:

  1. 将每个数据点分配到距离最近的聚类中心所在的组(使用欧氏距离)。
  2. 对每个组重新计算中心点(组内所有点各维度的算术平均值)。 重复上述过程指定的迭代次数后,输出最终的 $K$ 个聚类中心,每个维度的值保留两位小数(四舍五入)。

欧氏距离公式:$d(p, q) = \sqrt{(p_x - q_x)^2 + (p_y - q_y)^2 + (p_z - q_z)^2}$

解题思路

本题要求直接模拟 KMeans 算法的迭代过程。

  1. 数据结构: 每个点或中心点可以使用包含三个浮点数的结构或数组表示。

  2. 聚类分配: 在每一轮迭代中,遍历所有 $M$ 个数据点。对于每个点,计算它到 $K$ 个当前中心的欧氏距离。为了减少开销,比较距离时可以直接比较距离的平方值 $d^2$(省去开方运算,大小关系不变)。将其归类到距离最小的中心所属的簇。

  3. 中心更新: 分配完成后,遍历每个簇。如果簇内有数据点,计算这些点的坐标平均值作为新的聚类中心。如果某个簇为空(虽然在样例中未出现),通常保持原中心不变。

  4. 迭代与输出: 执行指定次数的迭代。最后按顺序输出 $K$ 个中心点。格式化输出时,C++ 使用 fixed 和 setprecision(2),Java 使用 String.format("%.2f"),Python 使用 format(num, ".2f")。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <iomanip>

using namespace std;

// A point (or cluster center) in 3-D feature space.
// Kept as a plain aggregate so it can be brace-initialized: Point{x, y, z}.
struct Point {
    double x;
    double y;
    double z;
};

// Returns the squared Euclidean distance between points a and b.
// No square root is taken: the result is only meant for comparing distances.
double get_dist_sq(const Point& a, const Point& b) {
    const double dx = a.x - b.x;
    const double dy = a.y - b.y;
    const double dz = a.z - b.z;
    return dx * dx + dy * dy + dz * dz;
}

int main() {
    int k;
    cin >> k;
    vector<Point> centers(k);
    for (int i = 0; i < k; ++i) {
        cin >> centers[i].x >> centers[i].y >> centers[i].z;
    }

    int iterations;
    cin >> iterations;
    int m;
    cin >> m;
    vector<Point> data(m);
    for (int i = 0; i < m; ++i) {
        cin >> data[i].x >> data[i].y >> data[i].z;
    }

    // KMeans 迭代过程
    for (int it = 0; it < iterations; ++it) {
        vector<vector<Point>> clusters(k);
        for (int i = 0; i < m; ++i) {
            int best_idx = 0;
            double min_dist_sq = get_dist_sq(data[i], centers[0]);
            for (int j = 1; j < k; ++j) {
                double d_sq = get_dist_sq(data[i], centers[j]);
                if (d_sq < min_dist_sq) {
                    min_dist_sq = d_sq;
                    best_idx = j;
                }
            }
            clusters[best_idx].push_back(data[i]);
        }

        // 更新中心点
        for (int i = 0; i < k; ++i) {
            if (clusters[i].empty()) continue;
            double sum_x = 0, sum_y = 0, sum_z = 0;
            for (const auto& p : clusters[i]) {
                sum_x += p.x;
                sum_y += p.y;
                sum_z += p.z;
            }
            int count = (int)clusters[i].size();
            centers[i] = {sum_x / count, sum_y / count, sum_z / count};
        }
    }

    // 输出最终中心点
    for (int i = 0; i < k; ++i) {
        cout << fixed << setprecision(2) << centers[i].x << " " 
             << centers[i].y << " " << centers[i].z << endl;
    }

    return 0;
}
import java.util.*;

/**
 * KMeans simulation over 3-D points.
 *
 * <p>Reads the problem from standard input (k, k centers, iteration count,
 * m, m data points), runs the requested number of assignment/update rounds,
 * and prints each final center on its own line with two decimals.
 */
public class Main {
    /** A point (or cluster center) in 3-D feature space. */
    static class Point {
        double x, y, z;

        Point(double x, double y, double z) {
            this.x = x;
            this.y = y;
            this.z = z;
        }
    }

    /**
     * Returns the squared Euclidean distance between {@code a} and {@code b}.
     * The square root is skipped because the value is only used for comparisons.
     */
    static double getDistSq(Point a, Point b) {
        double dx = a.x - b.x;
        double dy = a.y - b.y;
        double dz = a.z - b.z;
        return dx * dx + dy * dy + dz * dz;
    }

    /**
     * Returns the index of the center closest to {@code p}.
     * Ties go to the lowest index; {@code centers} must be non-empty.
     */
    private static int nearestCenter(Point p, Point[] centers) {
        int bestIdx = 0;
        double minDistSq = getDistSq(p, centers[0]);
        for (int j = 1; j < centers.length; j++) {
            double dSq = getDistSq(p, centers[j]);
            if (dSq < minDistSq) {
                minDistSq = dSq;
                bestIdx = j;
            }
        }
        return bestIdx;
    }

    public static void main(String[] args) {
        // Locale.US so that '.' is always accepted as the decimal separator.
        Scanner sc = new Scanner(System.in).useLocale(Locale.US);
        int k = sc.nextInt();
        if (k <= 0) {
            // No centers: nothing to assign points to and nothing to print.
            return;
        }
        Point[] centers = new Point[k];
        for (int i = 0; i < k; i++) {
            centers[i] = new Point(sc.nextDouble(), sc.nextDouble(), sc.nextDouble());
        }

        int iterations = sc.nextInt();
        int m = sc.nextInt();
        Point[] data = new Point[Math.max(m, 0)];
        for (int i = 0; i < data.length; i++) {
            data[i] = new Point(sc.nextDouble(), sc.nextDouble(), sc.nextDouble());
        }

        // Per-cluster running sums and member counts. This replaces an array
        // of generic lists (an unchecked creation) and avoids re-allocating
        // k lists per iteration; points are still accumulated in input order.
        double[] sumX = new double[k];
        double[] sumY = new double[k];
        double[] sumZ = new double[k];
        int[] counts = new int[k];

        for (int it = 0; it < iterations; it++) {
            Arrays.fill(sumX, 0.0);
            Arrays.fill(sumY, 0.0);
            Arrays.fill(sumZ, 0.0);
            Arrays.fill(counts, 0);

            // Assignment step.
            for (Point p : data) {
                int idx = nearestCenter(p, centers);
                sumX[idx] += p.x;
                sumY[idx] += p.y;
                sumZ[idx] += p.z;
                counts[idx]++;
            }

            // Update step: an empty cluster keeps its previous center.
            for (int i = 0; i < k; i++) {
                if (counts[i] == 0) {
                    continue;
                }
                centers[i] = new Point(sumX[i] / counts[i], sumY[i] / counts[i], sumZ[i] / counts[i]);
            }
        }

        for (Point c : centers) {
            System.out.println(String.format(Locale.US, "%.2f %.2f %.2f", c.x, c.y, c.z));
        }
    }
}
def solve():
    """Run the KMeans simulation on standard input and print the final centers.

    Input is consumed as a whitespace-separated token stream, so the exact
    placement of line breaks does not matter: ``k``, then ``k`` centers
    (three floats each), the iteration count, ``m``, then ``m`` data points
    (three floats each).

    Each final center is printed on its own line with every coordinate
    formatted to two decimals. Nothing is printed for empty input or when
    ``k`` is not positive.
    """
    import sys

    tokens = sys.stdin.read().split()
    if not tokens:
        return

    def read_triples(start, count):
        """Parse ``count`` consecutive (x, y, z) triples starting at ``start``."""
        return [
            [float(tokens[i]), float(tokens[i + 1]), float(tokens[i + 2])]
            for i in range(start, start + 3 * count, 3)
        ]

    k = int(tokens[0])
    if k <= 0:
        # No centers: nothing to assign points to and nothing to print.
        return
    centers = read_triples(1, k)

    pos = 1 + 3 * k
    iterations = int(tokens[pos])
    m = int(tokens[pos + 1])
    data_points = read_triples(pos + 2, m)

    def get_dist_sq(p1, p2):
        """Squared Euclidean distance; enough for nearest-center comparisons."""
        return (p1[0] - p2[0])**2 + (p1[1] - p2[1])**2 + (p1[2] - p2[2])**2

    for _ in range(iterations):
        # Assignment step: the strict '<' keeps the lowest index on ties.
        clusters = [[] for _ in range(k)]
        for p in data_points:
            best_idx = 0
            min_dist_sq = get_dist_sq(p, centers[0])
            for j in range(1, k):
                d_sq = get_dist_sq(p, centers[j])
                if d_sq < min_dist_sq:
                    min_dist_sq = d_sq
                    best_idx = j
            clusters[best_idx].append(p)

        # Update step: an empty cluster keeps its previous center.
        for i, members in enumerate(clusters):
            if not members:
                continue
            count = len(members)
            centers[i] = [sum(axis) / count for axis in zip(*members)]

    for c in centers:
        print(f"{c[0]:.2f} {c[1]:.2f} {c[2]:.2f}")


if __name__ == "__main__":
    solve()

算法及复杂度

  • 算法:KMeans 聚类模拟。
  • 时间复杂度:$O(T \cdot M \cdot K)$。其中 $T$ 为迭代次数,$M$ 为数据点数,$K$ 为聚类中心数。
  • 空间复杂度:$O(M + K)$,用于存储数据点和聚类中心。