K-Means聚类下的Anchor优化输出

题意

给定 个矩形框(宽和高),用基于 IOU 距离的 K-Means 聚类算法选出 个代表性 anchor 尺寸。初始中心取前 个框,迭代最多 次,最终按面积降序输出 个 anchor。

思路

IOU 距离是什么?

在目标检测里,两个框有多"相似"通常用 IOU(交并比)来衡量。这里的框都以左上角对齐来计算:

$$

距离就是 。IOU 越大,距离越小,说明两个框越像。

K-Means 流程

K-Means 本身的套路你应该很熟了,只是这里距离函数换成了 IOU 距离:

  1. 初始化:拿前 个框作为初始聚类中心。
  2. 分配:对每个框,算它到 个中心的 IOU 距离,归入最近的那个簇。
  3. 更新:每个簇内所有框的宽求平均、高求平均,分别向下取整得到新中心。如果某个簇空了,保留原中心不变。
  4. 收敛判断:算所有中心的位移之和(欧几里得距离),如果总位移 就停止;否则最多迭代 次。

细节别踩坑

有几个容易忽略的点:

  • 向下取整:题目说"取整",实际是 floor,不是四舍五入。Python 里用 math.floor(sum_w / cnt) 即可。
  • 空簇处理:如果某次迭代后某个簇没有分到任何框,这个簇的中心保持上一轮的值。
  • 排序规则:最后按面积降序输出。面积相同时按宽降序,再按高降序。

复杂度

每轮迭代对 个框各算 次 IOU,最多 轮,时间 。题目数据规模不大,直接模拟即可。

代码

import sys
import math

def iou(w1, h1, w2, h2):
    inter = min(w1, w2) * min(h1, h2)
    union = w1 * h1 + w2 * h2 - inter
    if union == 0:
        return 0.0
    return inter / union

def solve():
    data = sys.stdin.read().split()
    idx = 0
    N = int(data[idx]); idx += 1
    K = int(data[idx]); idx += 1
    T = int(data[idx]); idx += 1

    boxes = []
    for i in range(N):
        w = int(data[idx]); idx += 1
        h = int(data[idx]); idx += 1
        boxes.append((w, h))

    # 初始中心:前 K 个框
    centers = [(float(boxes[i][0]), float(boxes[i][1])) for i in range(K)]

    for iteration in range(T):
        # 分配:每个框归入 IOU 距离最近的簇
        clusters = [[] for _ in range(K)]
        for bw, bh in boxes:
            best_k = 0
            best_dist = float('inf')
            for k in range(K):
                cw, ch = centers[k]
                d = 1.0 - iou(bw, bh, cw, ch)
                if d < best_dist:
                    best_dist = d
                    best_k = k
            clusters[best_k].append((bw, bh))

        # 更新中心
        new_centers = []
        for k in range(K):
            if len(clusters[k]) == 0:
                new_centers.append(centers[k])
            else:
                sum_w = sum(b[0] for b in clusters[k])
                sum_h = sum(b[1] for b in clusters[k])
                cnt = len(clusters[k])
                nw = math.floor(sum_w / cnt)
                nh = math.floor(sum_h / cnt)
                new_centers.append((float(nw), float(nh)))

        # 收敛判断:总位移 < 1e-4
        total_disp = 0.0
        for k in range(K):
            dw = new_centers[k][0] - centers[k][0]
            dh = new_centers[k][1] - centers[k][1]
            total_disp += math.sqrt(dw * dw + dh * dh)

        centers = new_centers
        if total_disp < 1e-4:
            break

    # 按面积降序输出,面积相同按宽降序、高降序
    result = [(int(cw), int(ch)) for cw, ch in centers]
    result.sort(key=lambda x: (x[0] * x[1], x[0], x[1]), reverse=True)

    for w, h in result:
        print(w, h)

solve()