K-Means聚类下的Anchor优化输出
题意
给定 个矩形框(宽和高),用基于 IOU 距离的 K-Means 聚类算法选出
个代表性 anchor 尺寸。初始中心取前
个框,迭代最多
次,最终按面积降序输出
个 anchor。
思路
IOU 距离是什么?
在目标检测里,两个框有多"相似"通常用 IOU(交并比)来衡量。这里的框都以左上角对齐来计算:
$$
距离就是 。IOU 越大,距离越小,说明两个框越像。
K-Means 流程
K-Means 本身的套路你应该很熟了,只是这里距离函数换成了 IOU 距离:
- 初始化:拿前
个框作为初始聚类中心。
- 分配:对每个框,算它到
个中心的 IOU 距离,归入最近的那个簇。
- 更新:每个簇内所有框的宽求平均、高求平均,分别向下取整得到新中心。如果某个簇空了,保留原中心不变。
- 收敛判断:算所有中心的位移之和(欧几里得距离),如果总位移
就停止;否则最多迭代
次。
细节别踩坑
有几个容易忽略的点:
- 向下取整:题目说"取整",实际是 floor,不是四舍五入。Python 里用
math.floor(sum_w / cnt)即可。 - 空簇处理:如果某次迭代后某个簇没有分到任何框,这个簇的中心保持上一轮的值。
- 排序规则:最后按面积降序输出。面积相同时按宽降序,再按高降序。
复杂度
每轮迭代对 个框各算
次 IOU,最多
轮,时间
。题目数据规模不大,直接模拟即可。
代码
import sys
import math
def iou(w1, h1, w2, h2):
inter = min(w1, w2) * min(h1, h2)
union = w1 * h1 + w2 * h2 - inter
if union == 0:
return 0.0
return inter / union
def solve():
data = sys.stdin.read().split()
idx = 0
N = int(data[idx]); idx += 1
K = int(data[idx]); idx += 1
T = int(data[idx]); idx += 1
boxes = []
for i in range(N):
w = int(data[idx]); idx += 1
h = int(data[idx]); idx += 1
boxes.append((w, h))
# 初始中心:前 K 个框
centers = [(float(boxes[i][0]), float(boxes[i][1])) for i in range(K)]
for iteration in range(T):
# 分配:每个框归入 IOU 距离最近的簇
clusters = [[] for _ in range(K)]
for bw, bh in boxes:
best_k = 0
best_dist = float('inf')
for k in range(K):
cw, ch = centers[k]
d = 1.0 - iou(bw, bh, cw, ch)
if d < best_dist:
best_dist = d
best_k = k
clusters[best_k].append((bw, bh))
# 更新中心
new_centers = []
for k in range(K):
if len(clusters[k]) == 0:
new_centers.append(centers[k])
else:
sum_w = sum(b[0] for b in clusters[k])
sum_h = sum(b[1] for b in clusters[k])
cnt = len(clusters[k])
nw = math.floor(sum_w / cnt)
nh = math.floor(sum_h / cnt)
new_centers.append((float(nw), float(nh)))
# 收敛判断:总位移 < 1e-4
total_disp = 0.0
for k in range(K):
dw = new_centers[k][0] - centers[k][0]
dh = new_centers[k][1] - centers[k][1]
total_disp += math.sqrt(dw * dw + dh * dh)
centers = new_centers
if total_disp < 1e-4:
break
# 按面积降序输出,面积相同按宽降序、高降序
result = [(int(cw), int(ch)) for cw, ch in centers]
result.sort(key=lambda x: (x[0] * x[1], x[0], x[1]), reverse=True)
for w, h in result:
print(w, h)
solve()

京公网安备 11010502036488号