题目链接

K-Means聚类下的Anchor优化输出

题目描述

给定 个矩形框的宽和高,你需要使用一个定制的 K-Means 聚类算法来找出 个代表性的 Anchor 尺寸。

算法流程如下:

  1. 初始化: 取输入的前 个矩形框作为初始的 个聚类中心。

  2. 迭代: 最多进行 轮迭代。每一轮迭代包含: a. 分配: 将 个样本中的每一个都分配给与其距离最近的聚类中心。 b. 更新: 重新计算每个聚类的中心。新中心的宽/高是该聚类内所有样本宽/高的平均值,然后向下取整。如果一个聚类为空,其中心保持不变。

  3. 终止条件: 满足以下任一条件则停止迭代: a. 已达到最大迭代次数 。 b. 算法收敛:新旧中心之间的总“位移”小于 1e-4

  4. 输出: 按面积(宽 高)从大到小输出最终的 个中心。面积相同时,按宽降序,再按高降序。

距离度量:

  • 距离
  • 交集面积 =
  • 并集面积 =

解题思路

本题是对 K-Means 聚类算法的直接模拟。解题的关键是严格遵循题目给出的每一个细节,包括初始化、距离计算、中心点更新规则和终止条件。

  1. 数据结构

    • 使用一个结构体或类来表示矩形框(包含 wh)。
    • 用一个列表(如 vector)存储 个样本框。
    • 用另一个列表存储 个聚类中心。
  2. 距离计算函数

    • 实现一个函数 calculate_distance(box1, box2),该函数接收两个矩形框的尺寸,返回它们之间的距离
    • 注意:所有 IOU 和距离相关的计算都应使用浮点数(如 double)以保证精度。
  3. 主循环 (K-Means 迭代)

    • 初始化: 将输入的前 个样本框复制到聚类中心列表中。
    • 循环: 进行最多 次迭代。 a. 分配步骤: - 为 个聚类创建 个临时的列表,用于存放分配给它们的样本。 - 遍历 个样本框中的每一个。 - 对于每个样本,计算它到所有 个中心的距离,找到距离最小的那个中心,并将该样本添加到对应中心的临时列表中。 b. 更新步骤: - 创建一个列表 new_centers 用于存放本轮迭代计算出的新中心。 - 遍历 个聚类。 - 如果一个聚类不为空,计算其中所有样本的宽和高的平均值。然后将平均值向下取整,作为新中心的宽和高。 - 如果一个聚类为空,其中心保持不变,直接将旧中心复制为新中心。 c. 收敛检查: - 计算新旧中心之间的总位移。遍历 个中心,累加 calculate_distance(old_center[i], new_center[i])。 - 如果总位移小于 1e-4,则算法收敛,可以提前终止循环。 d. 更新中心: 将 new_centers 赋值给当前中心列表,准备下一轮迭代。
  4. 排序与输出

    • 迭代结束后,对最终的 个聚类中心进行排序。
    • 排序规则:首先按面积(w * h)降序;如果面积相等,则按宽度 w 降序;如果宽度仍相等,则按高度 h 降序。
    • 按格式输出排序后的结果。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <algorithm>

using namespace std;

struct Box {
    int w, h;
};

double calculate_iou(const Box& b1, const Box& b2) {
    double intersection_w = min(b1.w, b2.w);
    double intersection_h = min(b1.h, b2.h);
    double intersection_area = intersection_w * intersection_h;
    
    double union_area = (double)b1.w * b1.h + (double)b2.w * b2.h - intersection_area;
    
    return intersection_area / union_area;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int N, K, T;
    cin >> N >> K >> T;

    vector<Box> samples(N);
    for (int i = 0; i < N; ++i) {
        cin >> samples[i].w >> samples[i].h;
    }

    vector<Box> centers(K);
    for (int i = 0; i < K; ++i) {
        centers[i] = samples[i];
    }

    for (int t = 0; t < T; ++t) {
        vector<vector<Box>> clusters(K);
        for (const auto& sample : samples) {
            double min_dist = 2.0;
            int best_cluster = -1;
            for (int i = 0; i < K; ++i) {
                double dist = 1.0 - calculate_iou(sample, centers[i]);
                if (dist < min_dist) {
                    min_dist = dist;
                    best_cluster = i;
                }
            }
            clusters[best_cluster].push_back(sample);
        }

        vector<Box> new_centers(K);
        double total_displacement = 0.0;

        for (int i = 0; i < K; ++i) {
            if (clusters[i].empty()) {
                new_centers[i] = centers[i];
            } else {
                double sum_w = 0, sum_h = 0;
                for (const auto& box : clusters[i]) {
                    sum_w += box.w;
                    sum_h += box.h;
                }
                new_centers[i] = {(int)floor(sum_w / clusters[i].size()), (int)floor(sum_h / clusters[i].size())};
            }
            total_displacement += 1.0 - calculate_iou(centers[i], new_centers[i]);
        }

        centers = new_centers;

        if (total_displacement < 1e-4) {
            break;
        }
    }

    sort(centers.begin(), centers.end(), [](const Box& a, const Box& b) {
        long long area_a = (long long)a.w * a.h;
        long long area_b = (long long)b.w * b.h;
        if (area_a != area_b) return area_a > area_b;
        if (a.w != b.w) return a.w > b.w;
        return a.h > b.h;
    });

    for (int i = 0; i < K; ++i) {
        cout << centers[i].w << " " << centers[i].h << endl;
    }

    return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.List;
import java.util.Collections;
import java.util.Comparator;

class Box {
    int w, h;
    Box(int w, int h) {
        this.w = w;
        this.h = h;
    }
}

public class Main {
    private static double calculateIou(Box b1, Box b2) {
        double intersectionW = Math.min(b1.w, b2.w);
        double intersectionH = Math.min(b1.h, b2.h);
        double intersectionArea = intersectionW * intersectionH;
        
        double unionArea = (double)b1.w * b1.h + (double)b2.w * b2.h - intersectionArea;
        
        return intersectionArea / unionArea;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int K = sc.nextInt();
        int T = sc.nextInt();

        List<Box> samples = new ArrayList<>();
        for (int i = 0; i < N; i++) {
            samples.add(new Box(sc.nextInt(), sc.nextInt()));
        }

        List<Box> centers = new ArrayList<>();
        for (int i = 0; i < K; i++) {
            centers.add(samples.get(i));
        }

        for (int t = 0; t < T; t++) {
            List<List<Box>> clusters = new ArrayList<>();
            for (int i = 0; i < K; i++) {
                clusters.add(new ArrayList<>());
            }

            for (Box sample : samples) {
                double min_dist = 2.0;
                int best_cluster = -1;
                for (int i = 0; i < K; i++) {
                    double dist = 1.0 - calculateIou(sample, centers.get(i));
                    if (dist < min_dist) {
                        min_dist = dist;
                        best_cluster = i;
                    }
                }
                clusters.get(best_cluster).add(sample);
            }

            List<Box> newCenters = new ArrayList<>();
            double totalDisplacement = 0.0;

            for (int i = 0; i < K; i++) {
                if (clusters.get(i).isEmpty()) {
                    newCenters.add(centers.get(i));
                } else {
                    double sumW = 0, sumH = 0;
                    for (Box box : clusters.get(i)) {
                        sumW += box.w;
                        sumH += box.h;
                    }
                    newCenters.add(new Box((int)Math.floor(sumW / clusters.get(i).size()), (int)Math.floor(sumH / clusters.get(i).size())));
                }
                totalDisplacement += 1.0 - calculateIou(centers.get(i), newCenters.get(i));
            }

            centers = newCenters;
            if (totalDisplacement < 1e-4) {
                break;
            }
        }

        centers.sort((a, b) -> {
            long areaA = (long)a.w * a.h;
            long areaB = (long)b.w * b.h;
            if (areaA != areaB) return Long.compare(areaB, areaA);
            if (a.w != b.w) return Integer.compare(b.w, a.w);
            return Integer.compare(b.h, a.h);
        });

        for (Box center : centers) {
            System.out.println(center.w + " " + center.h);
        }
    }
}
import math

def calculate_iou(b1, b2):
    w1, h1 = b1
    w2, h2 = b2
    intersection_w = min(w1, w2)
    intersection_h = min(h1, h2)
    intersection_area = intersection_w * intersection_h
    
    union_area = w1 * h1 + w2 * h2 - intersection_area
    
    return intersection_area / union_area

def main():
    N, K, T = map(int, input().split())
    samples = [tuple(map(int, input().split())) for _ in range(N)]

    centers = samples[:K]

    for _ in range(T):
        clusters = [[] for _ in range(K)]
        for sample in samples:
            min_dist = 2.0
            best_cluster = -1
            for i in range(K):
                dist = 1.0 - calculate_iou(sample, centers[i])
                if dist < min_dist:
                    min_dist = dist
                    best_cluster = i
            clusters[best_cluster].append(sample)
        
        new_centers = []
        total_displacement = 0.0
        
        for i in range(K):
            if not clusters[i]:
                new_centers.append(centers[i])
            else:
                sum_w = sum(box[0] for box in clusters[i])
                sum_h = sum(box[1] for box in clusters[i])
                new_w = math.floor(sum_w / len(clusters[i]))
                new_h = math.floor(sum_h / len(clusters[i]))
                new_centers.append((new_w, new_h))
            
            total_displacement += 1.0 - calculate_iou(centers[i], new_centers[i])
            
        centers = new_centers
        
        if total_displacement < 1e-4:
            break
            
    centers.sort(key=lambda b: (b[0] * b[1], b[0], b[1]), reverse=True)
    
    for w, h in centers:
        print(w, h)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:K-Means聚类模拟
  • 时间复杂度:。其中 是最大迭代次数, 是样本数, 是聚类中心数。每一轮迭代都需要为 个样本计算到 个中心的距离。
  • 空间复杂度:。需要存储 个样本, 个中心,以及在迭代中临时存储聚类分配情况,空间主要由样本数决定。