题目链接
题目描述
给定 个矩形框的宽和高,你需要使用一个定制的 K-Means 聚类算法来找出
个代表性的 Anchor 尺寸。
算法流程如下:
-
初始化: 取输入的前
个矩形框作为初始的
个聚类中心。
-
迭代: 最多进行
轮迭代。每一轮迭代包含: a. 分配: 将
个样本中的每一个都分配给与其距离最近的聚类中心。 b. 更新: 重新计算每个聚类的中心。新中心的宽/高是该聚类内所有样本宽/高的平均值,然后向下取整。如果一个聚类为空,其中心保持不变。
-
终止条件: 满足以下任一条件则停止迭代: a. 已达到最大迭代次数
。 b. 算法收敛:新旧中心之间的总“位移”小于
1e-4。 -
输出: 按面积(宽
高)从大到小输出最终的
个中心。面积相同时,按宽降序,再按高降序。
距离度量:
- 距离
- 交集面积 =
- 并集面积 =
解题思路
本题是对 K-Means 聚类算法的直接模拟。解题的关键是严格遵循题目给出的每一个细节,包括初始化、距离计算、中心点更新规则和终止条件。
-
数据结构
- 使用一个结构体或类来表示矩形框(包含
w和h)。 - 用一个列表(如
vector)存储个样本框。
- 用另一个列表存储
个聚类中心。
- 使用一个结构体或类来表示矩形框(包含
-
距离计算函数
- 实现一个函数
calculate_distance(box1, box2),该函数接收两个矩形框的尺寸,返回它们之间的距离。
- 注意:所有 IOU 和距离相关的计算都应使用浮点数(如
double)以保证精度。
- 实现一个函数
-
主循环 (K-Means 迭代)
- 初始化: 将输入的前
个样本框复制到聚类中心列表中。
- 循环: 进行最多
次迭代。 a. 分配步骤: - 为
个聚类创建
个临时的列表,用于存放分配给它们的样本。 - 遍历
个样本框中的每一个。 - 对于每个样本,计算它到所有
个中心的距离,找到距离最小的那个中心,并将该样本添加到对应中心的临时列表中。 b. 更新步骤: - 创建一个列表
new_centers用于存放本轮迭代计算出的新中心。 - 遍历个聚类。 - 如果一个聚类不为空,计算其中所有样本的宽和高的平均值。然后将平均值向下取整,作为新中心的宽和高。 - 如果一个聚类为空,其中心保持不变,直接将旧中心复制为新中心。 c. 收敛检查: - 计算新旧中心之间的总位移。遍历
个中心,累加
calculate_distance(old_center[i], new_center[i])。 - 如果总位移小于1e-4,则算法收敛,可以提前终止循环。 d. 更新中心: 将new_centers赋值给当前中心列表,准备下一轮迭代。
- 初始化: 将输入的前
-
排序与输出
- 迭代结束后,对最终的
个聚类中心进行排序。
- 排序规则:首先按面积(
w * h)降序;如果面积相等,则按宽度w降序;如果宽度仍相等,则按高度h降序。 - 按格式输出排序后的结果。
- 迭代结束后,对最终的
代码
#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <algorithm>
using namespace std;
struct Box {
int w, h;
};
double calculate_iou(const Box& b1, const Box& b2) {
double intersection_w = min(b1.w, b2.w);
double intersection_h = min(b1.h, b2.h);
double intersection_area = intersection_w * intersection_h;
double union_area = (double)b1.w * b1.h + (double)b2.w * b2.h - intersection_area;
return intersection_area / union_area;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int N, K, T;
cin >> N >> K >> T;
vector<Box> samples(N);
for (int i = 0; i < N; ++i) {
cin >> samples[i].w >> samples[i].h;
}
vector<Box> centers(K);
for (int i = 0; i < K; ++i) {
centers[i] = samples[i];
}
for (int t = 0; t < T; ++t) {
vector<vector<Box>> clusters(K);
for (const auto& sample : samples) {
double min_dist = 2.0;
int best_cluster = -1;
for (int i = 0; i < K; ++i) {
double dist = 1.0 - calculate_iou(sample, centers[i]);
if (dist < min_dist) {
min_dist = dist;
best_cluster = i;
}
}
clusters[best_cluster].push_back(sample);
}
vector<Box> new_centers(K);
double total_displacement = 0.0;
for (int i = 0; i < K; ++i) {
if (clusters[i].empty()) {
new_centers[i] = centers[i];
} else {
double sum_w = 0, sum_h = 0;
for (const auto& box : clusters[i]) {
sum_w += box.w;
sum_h += box.h;
}
new_centers[i] = {(int)floor(sum_w / clusters[i].size()), (int)floor(sum_h / clusters[i].size())};
}
total_displacement += 1.0 - calculate_iou(centers[i], new_centers[i]);
}
centers = new_centers;
if (total_displacement < 1e-4) {
break;
}
}
sort(centers.begin(), centers.end(), [](const Box& a, const Box& b) {
long long area_a = (long long)a.w * a.h;
long long area_b = (long long)b.w * b.h;
if (area_a != area_b) return area_a > area_b;
if (a.w != b.w) return a.w > b.w;
return a.h > b.h;
});
for (int i = 0; i < K; ++i) {
cout << centers[i].w << " " << centers[i].h << endl;
}
return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.List;
import java.util.Collections;
import java.util.Comparator;
class Box {
int w, h;
Box(int w, int h) {
this.w = w;
this.h = h;
}
}
public class Main {
private static double calculateIou(Box b1, Box b2) {
double intersectionW = Math.min(b1.w, b2.w);
double intersectionH = Math.min(b1.h, b2.h);
double intersectionArea = intersectionW * intersectionH;
double unionArea = (double)b1.w * b1.h + (double)b2.w * b2.h - intersectionArea;
return intersectionArea / unionArea;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int N = sc.nextInt();
int K = sc.nextInt();
int T = sc.nextInt();
List<Box> samples = new ArrayList<>();
for (int i = 0; i < N; i++) {
samples.add(new Box(sc.nextInt(), sc.nextInt()));
}
List<Box> centers = new ArrayList<>();
for (int i = 0; i < K; i++) {
centers.add(samples.get(i));
}
for (int t = 0; t < T; t++) {
List<List<Box>> clusters = new ArrayList<>();
for (int i = 0; i < K; i++) {
clusters.add(new ArrayList<>());
}
for (Box sample : samples) {
double min_dist = 2.0;
int best_cluster = -1;
for (int i = 0; i < K; i++) {
double dist = 1.0 - calculateIou(sample, centers.get(i));
if (dist < min_dist) {
min_dist = dist;
best_cluster = i;
}
}
clusters.get(best_cluster).add(sample);
}
List<Box> newCenters = new ArrayList<>();
double totalDisplacement = 0.0;
for (int i = 0; i < K; i++) {
if (clusters.get(i).isEmpty()) {
newCenters.add(centers.get(i));
} else {
double sumW = 0, sumH = 0;
for (Box box : clusters.get(i)) {
sumW += box.w;
sumH += box.h;
}
newCenters.add(new Box((int)Math.floor(sumW / clusters.get(i).size()), (int)Math.floor(sumH / clusters.get(i).size())));
}
totalDisplacement += 1.0 - calculateIou(centers.get(i), newCenters.get(i));
}
centers = newCenters;
if (totalDisplacement < 1e-4) {
break;
}
}
centers.sort((a, b) -> {
long areaA = (long)a.w * a.h;
long areaB = (long)b.w * b.h;
if (areaA != areaB) return Long.compare(areaB, areaA);
if (a.w != b.w) return Integer.compare(b.w, a.w);
return Integer.compare(b.h, a.h);
});
for (Box center : centers) {
System.out.println(center.w + " " + center.h);
}
}
}
import math
def calculate_iou(b1, b2):
w1, h1 = b1
w2, h2 = b2
intersection_w = min(w1, w2)
intersection_h = min(h1, h2)
intersection_area = intersection_w * intersection_h
union_area = w1 * h1 + w2 * h2 - intersection_area
return intersection_area / union_area
def main():
N, K, T = map(int, input().split())
samples = [tuple(map(int, input().split())) for _ in range(N)]
centers = samples[:K]
for _ in range(T):
clusters = [[] for _ in range(K)]
for sample in samples:
min_dist = 2.0
best_cluster = -1
for i in range(K):
dist = 1.0 - calculate_iou(sample, centers[i])
if dist < min_dist:
min_dist = dist
best_cluster = i
clusters[best_cluster].append(sample)
new_centers = []
total_displacement = 0.0
for i in range(K):
if not clusters[i]:
new_centers.append(centers[i])
else:
sum_w = sum(box[0] for box in clusters[i])
sum_h = sum(box[1] for box in clusters[i])
new_w = math.floor(sum_w / len(clusters[i]))
new_h = math.floor(sum_h / len(clusters[i]))
new_centers.append((new_w, new_h))
total_displacement += 1.0 - calculate_iou(centers[i], new_centers[i])
centers = new_centers
if total_displacement < 1e-4:
break
centers.sort(key=lambda b: (b[0] * b[1], b[0], b[1]), reverse=True)
for w, h in centers:
print(w, h)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:K-Means聚类模拟
- 时间复杂度:
。其中
是最大迭代次数,
是样本数,
是聚类中心数。每一轮迭代都需要为
个样本计算到
个中心的距离。
- 空间复杂度:
。需要存储
个样本,
个中心,以及在迭代中临时存储聚类分配情况,空间主要由样本数决定。

京公网安备 11010502036488号