二分 K-means 子网分割
题意
给定 个二维坐标点,要求用"二分 K-means"(Bisecting K-means)算法将它们分成
个簇。算法从一个包含全部点的簇出发,每次选 SSE(簇内各点到簇心的欧氏距离平方和)最大的簇进行二分裂,直到簇的数量达到
。每次二分裂后,输出当前所有簇的大小(降序排列)。
思路
这道题没有什么算法上的巧思,核心就是把题意里的二分 K-means 流程忠实地模拟出来。关键在于把每个细节对齐:
1. 选哪个簇来分裂?
每次选 SSE 最大的簇。SSE 就是簇内每个点到簇心(平均坐标)的欧氏距离的平方和:
$$
2. 怎么初始化二分裂的两个簇心?
在要分裂的簇内,找 坐标最小的点作为簇心 A,
坐标最大的点作为簇心 B。如果有多个点
坐标相同,按
坐标再按输入顺序来打破并列。
3. 迭代分配过程?
标准 K-means 迭代:每个点按欧氏距离分给最近的簇心(距离相同时分给第一个簇心),然后用簇内均值更新簇心。如果某个簇变空了,保持它的簇心不变继续迭代。
收敛条件:两个簇心的总移动量小于 ,或者迭代次数达到 1000 次。
4. 输出什么?
一共 行,第
行输出第
次二分裂后所有簇的大小,降序排列。
整体流程就是一个循环:找最大 SSE 簇 → 二分裂 → 记录结果,重复 次。
复杂度
- 时间:
,其中
是 K-means 迭代次数(最多 1000)
- 空间:
代码
import sys
def solve():
data = sys.stdin.read().split()
idx = 0
N = int(data[idx]); idx += 1
M = int(data[idx]); idx += 1
points = []
for i in range(M):
x = int(data[idx]); idx += 1
y = int(data[idx]); idx += 1
points.append((x, y, i))
def compute_sse(cluster):
if not cluster:
return 0.0
cx = sum(p[0] for p in cluster) / len(cluster)
cy = sum(p[1] for p in cluster) / len(cluster)
return sum((p[0] - cx) ** 2 + (p[1] - cy) ** 2 for p in cluster)
def bisect(cluster):
if len(cluster) <= 1:
return [cluster]
# 初始簇心:x最小点 和 x最大点
min_pt = min(cluster, key=lambda p: (p[0], p[1], p[2]))
max_pt = max(cluster, key=lambda p: (p[0], -p[1], -p[2]))
c1 = [min_pt[0], min_pt[1]]
c2 = [max_pt[0], max_pt[1]]
for _ in range(1000):
g1, g2 = [], []
for p in cluster:
d1 = (p[0] - c1[0]) ** 2 + (p[1] - c1[1]) ** 2
d2 = (p[0] - c2[0]) ** 2 + (p[1] - c2[1]) ** 2
if d1 <= d2:
g1.append(p)
else:
g2.append(p)
nc1, nc2 = list(c1), list(c2)
if g1:
nc1 = [sum(p[0] for p in g1) / len(g1),
sum(p[1] for p in g1) / len(g1)]
if g2:
nc2 = [sum(p[0] for p in g2) / len(g2),
sum(p[1] for p in g2) / len(g2)]
move = ((nc1[0]-c1[0])**2 + (nc1[1]-c1[1])**2)**0.5 + \
((nc2[0]-c2[0])**2 + (nc2[1]-c2[1])**2)**0.5
c1, c2 = nc1, nc2
if move < 1e-6:
break
return [g1, g2]
clusters = [list(points)]
res = []
for _ in range(N - 1):
best_i = max(range(len(clusters)), key=lambda i: compute_sse(clusters[i]))
target = clusters.pop(best_i)
clusters.extend(bisect(target))
sizes = sorted([len(c) for c in clusters], reverse=True)
res.append(' '.join(map(str, sizes)))
print('\n'.join(res))
solve()

京公网安备 11010502036488号