均衡版 KMeans 分群与新用户归类
题目分析
给定 位客户的
维特征向量,使用均衡版 KMeans 将其划分为
个群组,每组容量严格均衡(
或
),迭代至收敛后,对新客户进行归类。
思路
模拟均衡 KMeans 聚类
与标准 KMeans 不同,本题要求每个簇的容量严格均衡,因此分配阶段需要加入容量限制。
初始化:取前 个客户的特征作为初始质心。容量分配为:前
个质心容量为
,其余为
。
分配阶段:按客户编号 到
顺序处理。对每个客户,计算到所有未满员质心的平方欧氏距离,选距离最小的质心(并列选编号小的)。
更新阶段:每个质心更新为该簇内所有成员特征的逐维均值向下取整。
收敛判断:若本轮的分配结果和质心与上一轮完全一致,则停止迭代。
新用户归类:将最终质心按字典序排序,计算新客户到各质心的平方欧氏距离,选最近的质心(并列选字典序最小的),输出其在排序列表中的位置(1-indexed)。
关键细节:
- 距离使用平方欧氏距离,避免浮点误差。
- 质心更新使用整数除法向下取整,因此质心始终是整数,收敛判断可以直接比较。
- 容量限制使得分配顺序会影响结果,必须按客户编号顺序处理。
代码
import sys
def main():
data = sys.stdin.read().split()
idx = 0
N = int(data[idx]); idx += 1
M = int(data[idx]); idx += 1
K = int(data[idx]); idx += 1
customers = []
for i in range(N):
feat = []
for j in range(M):
feat.append(int(data[idx])); idx += 1
customers.append(feat)
new_cust = []
for j in range(M):
new_cust.append(int(data[idx])); idx += 1
centers = [list(customers[i]) for i in range(K)]
base = N // K
extra = N % K
caps = [base + 1 if i < extra else base for i in range(K)]
def sq_dist(a, b):
s = 0
for i in range(M):
d = a[i] - b[i]
s += d * d
return s
prev_assign = None
prev_centers = None
while True:
assign = [[] for _ in range(K)]
filled = [0] * K
assignment = [0] * N
for ci in range(N):
best_center = -1
best_dist = -1
for ki in range(K):
if filled[ki] >= caps[ki]:
continue
d = sq_dist(customers[ci], centers[ki])
if best_center == -1 or d < best_dist:
best_dist = d
best_center = ki
assign[best_center].append(ci)
filled[best_center] += 1
assignment[ci] = best_center
new_centers = []
for ki in range(K):
nc = []
for j in range(M):
s = sum(customers[ci][j] for ci in assign[ki])
nc.append(s // len(assign[ki]))
new_centers.append(nc)
if assignment == prev_assign and new_centers == prev_centers:
centers = new_centers
break
prev_assign = assignment
prev_centers = new_centers
centers = new_centers
sorted_centers = sorted(centers)
out = []
for c in sorted_centers:
out.append(' '.join(map(str, c)))
best_idx = -1
best_dist = -1
for i, c in enumerate(sorted_centers):
d = sq_dist(new_cust, c)
if best_idx == -1 or d < best_dist or (d == best_dist and c < sorted_centers[best_idx]):
best_dist = d
best_idx = i
out.append(str(best_idx + 1))
sys.stdout.write('\n'.join(out) + '\n')
main()
复杂度分析
- 时间复杂度:
,其中
为迭代轮数。每轮对
个客户计算到
个质心的
维距离。由于质心取整,收敛通常很快。
- 空间复杂度:
,存储客户特征和质心。

京公网安备 11010502036488号