题目链接
题目描述
某电商平台希望根据用户的三个特征指标:月均消费金额()、月均访问次数(
)和归一化后的退货率(
),对用户进行分群。
你需要实现 KMeans 聚类算法。给定
个初始聚类中心和
个数据点,按照以下流程迭代:
- 将每个数据点分配到距离最近的聚类中心所在的组(使用欧氏距离)。
- 对每个组重新计算中心点(组内所有点各维度的算术平均值)。
重复上述过程指定的迭代次数后,输出最终的
个聚类中心,每个维度的值保留两位小数(四舍五入)。
欧氏距离公式:。
解题思路
本题要求直接模拟 KMeans 算法的迭代过程。
-
数据结构: 每个点或中心点可以使用包含三个浮点数的结构或数组表示。
-
聚类分配: 在每一轮迭代中,遍历所有
个数据点。对于每个点,计算它到
个当前中心的欧氏距离。为了减少开销,比较距离时可以比较距离的平方值
。将其归类到距离最小的中心所属的簇。
-
中心更新: 分配完成后,遍历每个簇。如果簇内有数据点,计算这些点的坐标平均值作为新的聚类中心。如果某个簇为空(虽然在样例中未出现),通常保持原中心不变。
-
迭代与输出: 执行指定次数的迭代。最后按顺序输出
个中心点。格式化输出时,C++ 使用
fixed和setprecision(2),Java 使用String.format("%.2f"),Python 使用format(num, ".2f")。
代码
#include <iostream>
#include <vector>
#include <cmath>
#include <iomanip>
using namespace std;
// 定义三维空间中的点
struct Point {
double x, y, z;
};
// 计算两个点之间的欧氏距离的平方
double get_dist_sq(const Point& a, const Point& b) {
return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y) + (a.z - b.z) * (a.z - b.z);
}
int main() {
int k;
cin >> k;
vector<Point> centers(k);
for (int i = 0; i < k; ++i) {
cin >> centers[i].x >> centers[i].y >> centers[i].z;
}
int iterations;
cin >> iterations;
int m;
cin >> m;
vector<Point> data(m);
for (int i = 0; i < m; ++i) {
cin >> data[i].x >> data[i].y >> data[i].z;
}
// KMeans 迭代过程
for (int it = 0; it < iterations; ++it) {
vector<vector<Point>> clusters(k);
for (int i = 0; i < m; ++i) {
int best_idx = 0;
double min_dist_sq = get_dist_sq(data[i], centers[0]);
for (int j = 1; j < k; ++j) {
double d_sq = get_dist_sq(data[i], centers[j]);
if (d_sq < min_dist_sq) {
min_dist_sq = d_sq;
best_idx = j;
}
}
clusters[best_idx].push_back(data[i]);
}
// 更新中心点
for (int i = 0; i < k; ++i) {
if (clusters[i].empty()) continue;
double sum_x = 0, sum_y = 0, sum_z = 0;
for (const auto& p : clusters[i]) {
sum_x += p.x;
sum_y += p.y;
sum_z += p.z;
}
int count = (int)clusters[i].size();
centers[i] = {sum_x / count, sum_y / count, sum_z / count};
}
}
// 输出最终中心点
for (int i = 0; i < k; ++i) {
cout << fixed << setprecision(2) << centers[i].x << " "
<< centers[i].y << " " << centers[i].z << endl;
}
return 0;
}
import java.util.*;
public class Main {
static class Point {
double x, y, z;
Point(double x, double y, double z) {
this.x = x;
this.y = y;
this.z = z;
}
}
static double getDistSq(Point a, Point b) {
return (a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y) + (a.z - b.z) * (a.z - b.z);
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in).useLocale(Locale.US);
int k = sc.nextInt();
Point[] centers = new Point[k];
for (int i = 0; i < k; i++) {
centers[i] = new Point(sc.nextDouble(), sc.nextDouble(), sc.nextDouble());
}
int iterations = sc.nextInt();
int m = sc.nextInt();
Point[] data = new Point[m];
for (int i = 0; i < m; i++) {
data[i] = new Point(sc.nextDouble(), sc.nextDouble(), sc.nextDouble());
}
for (int it = 0; it < iterations; it++) {
List<Point>[] clusters = new ArrayList[k];
for (int i = 0; i < k; i++) clusters[i] = new ArrayList<>();
for (int i = 0; i < m; i++) {
int bestIdx = 0;
double minDistSq = getDistSq(data[i], centers[0]);
for (int j = 1; j < k; j++) {
double dSq = getDistSq(data[i], centers[j]);
if (dSq < minDistSq) {
minDistSq = dSq;
bestIdx = j;
}
}
clusters[bestIdx].add(data[i]);
}
for (int i = 0; i < k; i++) {
if (clusters[i].isEmpty()) continue;
double sumX = 0, sumY = 0, sumZ = 0;
for (Point p : clusters[i]) {
sumX += p.x;
sumY += p.y;
sumZ += p.z;
}
int count = clusters[i].size();
centers[i] = new Point(sumX / count, sumY / count, sumZ / count);
}
}
for (int i = 0; i < k; i++) {
System.out.println(String.format(Locale.US, "%.2f %.2f %.2f", centers[i].x, centers[i].y, centers[i].z));
}
}
}
def solve():
import sys
# 读取 K
k = int(input())
centers = []
for _ in range(k):
centers.append(list(map(float, input().split())))
# 读取迭代次数和数据点个数
iterations = int(input())
m = int(input())
data_points = []
for _ in range(m):
data_points.append(list(map(float, input().split())))
def get_dist_sq(p1, p2):
return (p1[0] - p2[0])**2 + (p1[1] - p2[1])**2 + (p1[2] - p2[2])**2
# KMeans 迭代过程
for _ in range(iterations):
clusters = [[] for _ in range(k)]
for p in data_points:
best_idx = 0
min_dist_sq = get_dist_sq(p, centers[0])
for j in range(1, k):
d_sq = get_dist_sq(p, centers[j])
if d_sq < min_dist_sq:
min_dist_sq = d_sq
best_idx = j
clusters[best_idx].append(p)
# 更新中心点
for i in range(k):
if not clusters[i]:
continue
sum_x = sum(p[0] for p in clusters[i])
sum_y = sum(p[1] for p in clusters[i])
sum_z = sum(p[2] for p in clusters[i])
count = len(clusters[i])
centers[i] = [sum_x / count, sum_y / count, sum_z / count]
# 输出最终中心点
for c in centers:
print(f"{c[0]:.2f} {c[1]:.2f} {c[2]:.2f}")
solve()
算法及复杂度
- 算法:KMeans 聚类模拟。
- 时间复杂度:
。其中
为迭代次数,
为数据点数,
为聚类中心数。
- 空间复杂度:
,用于存储数据点和聚类中心。

京公网安备 11010502036488号