
Bisecting K-means Subnet Partitioning

Problem Description

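Given N (the desired number of subnets) and M two-dimensional site coordinates, partition the sites into subnets using the "bisecting K-means" approach. The procedure is as follows:

  1. Start with a single cluster containing all sites.
  2. Repeat the following N-1 times:
     a. Among all current clusters, find the one with the largest SSE (the sum of squared distances from each point in the cluster to the cluster center).
     b. Split that cluster into two new sub-clusters with one run of standard K-means (K=2).
     c. After each split, output the sizes (site counts) of all current clusters in descending order.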

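Rules for K-means (K=2):

  • Initial centers: choose the point with the smallest x coordinate and the point with the largest x coordinate in the cluster. Break ties by y coordinate, then by input order.
  • Iteration: repeat the "assignment" and "update" steps until convergence.
    • Assignment: each point is assigned to the nearest center by Euclidean distance.
    • Update: the new center is the mean of the coordinates of the points in the cluster.
    • Special rule: if a cluster is empty, its center stays unchanged.
  • Convergence: the total distance moved by the centers is less than 1e-6, or 1000 iterations have been completed.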

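Approach

This is a highly customized clustering simulation, so the solution must follow the procedure in the problem statement exactly. The key is to break the whole bisecting K-means process into small, manageable modules.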

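1. Overall framework: top-down bisecting

The algorithm is a single main loop. It starts with k=1 cluster (containing all points) and runs N-1 times; each iteration adds one cluster, until there are N clusters.

Each iteration of the main loop performs the following steps:

  • a. Choose the split target: go through all current clusters and compute the SSE (Sum of Squared Errors) of each. SSE measures how spread out the points in a cluster are, so a larger value marks a better candidate for splitting. The cluster with the largest SSE becomes the split target.
  • b. Split: call the k_means_split subroutine on the target cluster to split it into two new sub-clusters.
  • c. Update the cluster list: remove the split parent cluster from the list and add the two new sub-clusters (the code below drops a sub-cluster that comes back empty).
  • d. Output: collect the sizes of all current clusters, sort them in descending order, and print them.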
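
The loop above can be sketched in Python as follows. This is a minimal illustration rather than the full solution: `bisect_step`, `sse_fn`, and `split_fn` are hypothetical names introduced here, and the scoring and splitting functions are passed in as parameters so the skeleton can be read on its own. The toy stand-ins at the bottom are not real SSE or K-means; they only make the skeleton runnable.

```python
def bisect_step(clusters, sse_fn, split_fn):
    """Perform one round of bisecting: split the cluster with the largest score.

    clusters: list of clusters (each a list of points), modified in place
    sse_fn:   hypothetical scoring function, cluster -> float
    split_fn: hypothetical splitter, cluster -> (part_a, part_b)
    Returns the cluster sizes in descending order after the split.
    """
    # a. choose the cluster with the largest score (the first one wins on ties)
    target = max(range(len(clusters)), key=lambda i: sse_fn(clusters[i]))
    # b. split it, and c. replace the parent with its non-empty children
    part_a, part_b = split_fn(clusters.pop(target))
    clusters.extend(part for part in (part_a, part_b) if part)
    # d. report the sizes in descending order
    return sorted((len(c) for c in clusters), reverse=True)


# Toy stand-ins (assumptions for this demo only): 1-D "points",
# spread measured as max - min, and a split at the midpoint of the range.
def toy_sse(cluster):
    return max(cluster) - min(cluster)

def toy_split(cluster):
    mid = (max(cluster) + min(cluster)) / 2
    return [p for p in cluster if p <= mid], [p for p in cluster if p > mid]

clusters = [[1, 2, 10, 11, 12]]
sizes_after_first = bisect_step(clusters, toy_sse, toy_split)
sizes_after_second = bisect_step(clusters, toy_sse, toy_split)
print(sizes_after_first)
print(sizes_after_second)
```

Passing the scoring and splitting logic in as functions keeps the selection, replacement, and output steps separate from the K-means details described next.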

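2. Core subroutine: K-means split (K=2)

This function takes a set of points and partitions it into two subsets.

  • a. Data preparation: to support the tie-breaking rule, each point's coordinates are bound to its original input index when the data is first read.

  • b. Initialize the centers

    • Sort by the rule "x coordinate -> y coordinate -> original index" and take the point with the smallest x and the point with the largest x as the initial centers.
    • This multi-key comparison needs a custom sort comparator.
  • c. Iterative refinement

    • Loop for at most 1000 iterations.
    • Save the old centers: before the centers are updated, store the current positions of both centers so the movement can be computed later.
    • Assignment: create two empty sub-clusters. For every point in the current cluster, compute its Euclidean distance to both centers and assign it to the sub-cluster of the nearer center. In the code below, a point equidistant from both centers goes to the first one.
    • Update
      • Go through the two newly formed sub-clusters.
      • If a sub-cluster is non-empty, the mean of its points' coordinates becomes the new center.
      • If a sub-cluster is empty, its center stays unchanged (equal to the previous center).
    • Check convergence: compute the sum of the distances the two centers moved from their old positions to their new ones. If the total is less than 1e-6, exit the loop early.
  • d. Return the result: after the loop ends, return the two resulting sub-clusters (point sets).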
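
As a small illustration of steps b and c, the sketch below picks the initial centers with `min`/`max` over the key (x, y, index) instead of a full sort, then runs one assignment and update pass. The function names and the sample points are made up for this example. Taking the largest key for the second center mirrors the sorted-order choice in the code below, which is one reading of the tie-breaking rule.

```python
import math

def initial_centers(points):
    """points: list of (x, y, input_index). Returns the two starting centers."""
    lo = min(points)  # tuples compare by x, then y, then input index
    hi = max(points)
    return (lo[0], lo[1]), (hi[0], hi[1])

def assign_and_update(points, c1, c2):
    """One K-means pass: assign each point to the nearer center, then re-average."""
    groups = ([], [])
    for x, y, _ in points:
        d1 = (x - c1[0]) ** 2 + (y - c1[1]) ** 2
        d2 = (x - c2[0]) ** 2 + (y - c2[1]) ** 2
        groups[0 if d1 <= d2 else 1].append((x, y))  # ties go to the first center
    new_centers = []
    for group, old in zip(groups, (c1, c2)):
        if group:
            new_centers.append((sum(p[0] for p in group) / len(group),
                                sum(p[1] for p in group) / len(group)))
        else:
            new_centers.append(old)  # an empty cluster keeps its center
    # Total movement: sum of the Euclidean distances both centers travelled
    moved = math.dist(c1, new_centers[0]) + math.dist(c2, new_centers[1])
    return groups, new_centers, moved

pts = [(0.0, 0.0, 0), (0.0, 2.0, 1), (4.0, 0.0, 2), (4.0, 2.0, 3)]
c1, c2 = initial_centers(pts)
groups, centers, moved = assign_and_update(pts, c1, c2)
print(c1, c2)          # initial centers (0.0, 0.0) and (4.0, 2.0)
print(centers, moved)  # centers move to (0.0, 1.0) and (4.0, 1.0), total movement 2.0
```

A second pass from these centers leaves both groups unchanged, so the total movement drops to 0 and the loop would stop.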

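3. SSE (Sum of Squared Errors) computation

This helper step (inlined in the main loop in the code below) takes a set of points and computes its SSE:

  • First, compute the centroid of the set (the mean of all point coordinates).
  • Then, for each point in the set, accumulate the squared Euclidean distance to the centroid.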
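
A minimal Python sketch of this computation, using a hypothetical helper name `sse` and points given as (x, y) pairs:

```python
def sse(points):
    """Sum of squared Euclidean distances from each (x, y) point to the centroid."""
    if not points:
        return 0.0
    cx = sum(x for x, _ in points) / len(points)
    cy = sum(y for _, y in points) / len(points)
    return sum((x - cx) ** 2 + (y - cy) ** 2 for x, y in points)

print(sse([(0, 0), (2, 0)]))                  # centroid (1, 0): 1 + 1 = 2.0
print(sse([(0, 0), (0, 2), (4, 0), (4, 2)]))  # centroid (2, 1): 4 * (4 + 1) = 20.0
print(sse([(3, 7)]))                          # a single point has zero spread: 0.0
```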

Combining these modules reproduces the whole procedure.

Code

C++ implementation:

#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <algorithm>
#include <iomanip>

using namespace std;

struct Point {
    int id;
    double x, y;
};

struct Cluster {
    vector<int> point_indices;
    pair<double, double> centroid;
    double sse;
};

double euclidean_dist_sq(double x1, double y1, double x2, double y2) {
    return (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2);
}

// K-means (K=2) split: partitions the given point indices into two sub-clusters
pair<vector<int>, vector<int>> k_means_split(const vector<int>& point_indices, const vector<Point>& all_points) {
    if (point_indices.size() <= 1) {
        return {point_indices, {}};
    }

    // 1. Initialize the centers: sort by x, then y, then input order
    auto points_in_cluster = point_indices;
    sort(points_in_cluster.begin(), points_in_cluster.end(), [&](int a, int b) {
        const auto& p1 = all_points[a];
        const auto& p2 = all_points[b];
        if (p1.x != p2.x) return p1.x < p2.x;
        if (p1.y != p2.y) return p1.y < p2.y;
        return p1.id < p2.id;
    });
    
    Point center1_pt = all_points[points_in_cluster.front()];
    Point center2_pt = all_points[points_in_cluster.back()];
    
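    // Safeguard: both centers are the same point. Because ties are broken by id, this
    // cannot trigger once the cluster has two or more points, so coincident points are
    // not caught here; they all end up in the first sub-cluster in the loop below.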
    if (center1_pt.id == center2_pt.id) {
         vector<int> sub1, sub2;
         for(size_t i = 0; i < point_indices.size(); ++i) {
             if (i < point_indices.size() / 2) sub1.push_back(point_indices[i]);
             else sub2.push_back(point_indices[i]);
         }
         return {sub1, sub2};
    }

    pair<double, double> centroid1 = {center1_pt.x, center1_pt.y};
    pair<double, double> centroid2 = {center2_pt.x, center2_pt.y};

    vector<int> cluster1, cluster2;

    for (int iter = 0; iter < 1000; ++iter) {
        cluster1.clear();
        cluster2.clear();

        // 2. Assignment: each point goes to the nearer center (ties go to center 1)
        for (int idx : point_indices) {
            double dist1 = euclidean_dist_sq(all_points[idx].x, all_points[idx].y, centroid1.first, centroid1.second);
            double dist2 = euclidean_dist_sq(all_points[idx].x, all_points[idx].y, centroid2.first, centroid2.second);
            if (dist1 <= dist2) {
                cluster1.push_back(idx);
            } else {
                cluster2.push_back(idx);
            }
        }

        // 3. Update: remember the old centers, then recompute the mean of each non-empty cluster
        pair<double, double> old_centroid1 = centroid1;
        pair<double, double> old_centroid2 = centroid2;

        if (!cluster1.empty()) {
            double sum_x = 0, sum_y = 0;
            for (int idx : cluster1) { sum_x += all_points[idx].x; sum_y += all_points[idx].y; }
            centroid1 = {sum_x / cluster1.size(), sum_y / cluster1.size()};
        }
        if (!cluster2.empty()) {
            double sum_x = 0, sum_y = 0;
            for (int idx : cluster2) { sum_x += all_points[idx].x; sum_y += all_points[idx].y; }
            centroid2 = {sum_x / cluster2.size(), sum_y / cluster2.size()};
        }

        // 4. Convergence check: total distance moved by both centers
        double move_dist = sqrt(euclidean_dist_sq(centroid1.first, centroid1.second, old_centroid1.first, old_centroid1.second)) +
                           sqrt(euclidean_dist_sq(centroid2.first, centroid2.second, old_centroid2.first, old_centroid2.second));
        if (move_dist < 1e-6) {
            break;
        }
    }
    return {cluster1, cluster2};
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n_target, m_points;
    cin >> n_target >> m_points;
    vector<Point> all_points(m_points);
    for (int i = 0; i < m_points; ++i) {
        all_points[i].id = i;
        cin >> all_points[i].x >> all_points[i].y;
    }

    vector<Cluster> clusters;
    Cluster initial_cluster;
    initial_cluster.point_indices.resize(m_points);
    iota(initial_cluster.point_indices.begin(), initial_cluster.point_indices.end(), 0);
    clusters.push_back(initial_cluster);

    for (int k = 1; k < n_target; ++k) {
        // 1. Find the cluster with the largest SSE
        int split_idx = -1;
        double max_sse = -1.0;
        for (size_t i = 0; i < clusters.size(); ++i) {
            if (clusters[i].point_indices.size() <= 1) {
                clusters[i].sse = 0;
                continue;
            }
            double sum_x = 0, sum_y = 0;
            for (int p_idx : clusters[i].point_indices) {
                sum_x += all_points[p_idx].x;
                sum_y += all_points[p_idx].y;
            }
            clusters[i].centroid = {sum_x / clusters[i].point_indices.size(), sum_y / clusters[i].point_indices.size()};
            
            double current_sse = 0;
            for (int p_idx : clusters[i].point_indices) {
                current_sse += euclidean_dist_sq(all_points[p_idx].x, all_points[p_idx].y, clusters[i].centroid.first, clusters[i].centroid.second);
            }
            clusters[i].sse = current_sse;

            if (clusters[i].sse > max_sse) {
                max_sse = clusters[i].sse;
                split_idx = i;
            }
        }

        // 2. Split it and replace the parent with its non-empty children
        vector<int> to_split_indices = clusters[split_idx].point_indices;
        clusters.erase(clusters.begin() + split_idx);
        
        auto new_clusters_indices = k_means_split(to_split_indices, all_points);
        
        if (!new_clusters_indices.first.empty()) clusters.push_back({new_clusters_indices.first});
        if (!new_clusters_indices.second.empty()) clusters.push_back({new_clusters_indices.second});

        // 3. Output the cluster sizes in descending order
        vector<int> sizes;
        for (const auto& c : clusters) {
            sizes.push_back(c.point_indices.size());
        }
        sort(sizes.rbegin(), sizes.rend());
        for (size_t i = 0; i < sizes.size(); ++i) {
            cout << sizes[i] << (i == sizes.size() - 1 ? "" : " ");
        }
        cout << "\n";
    }

    return 0;
}

Java implementation:

import java.util.*;

class Point {
    int id;
    double x, y;

    Point(int id, double x, double y) {
        this.id = id;
        this.x = x;
        this.y = y;
    }
}

class Cluster {
    List<Integer> pointIndices;
    double centroidX, centroidY;
    double sse;

    Cluster(List<Integer> pointIndices) {
        this.pointIndices = new ArrayList<>(pointIndices);
    }
}

public class Main {
    static List<Point> allPoints;

    private static double euclideanDistSq(double x1, double y1, double x2, double y2) {
        return (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2);
    }

    // K-means (K=2) split: partitions the given point indices into two sub-clusters
    private static List<List<Integer>> kMeansSplit(List<Integer> pointIndices) {
        List<List<Integer>> result = new ArrayList<>();
        if (pointIndices.size() <= 1) {
            result.add(new ArrayList<>(pointIndices));
            result.add(new ArrayList<>());
            return result;
        }

        List<Integer> pointsInCluster = new ArrayList<>(pointIndices);
        pointsInCluster.sort((a, b) -> {
            Point p1 = allPoints.get(a);
            Point p2 = allPoints.get(b);
            if (p1.x != p2.x) return Double.compare(p1.x, p2.x);
            if (p1.y != p2.y) return Double.compare(p1.y, p2.y);
            return Integer.compare(p1.id, p2.id);
        });

        Point center1Pt = allPoints.get(pointsInCluster.get(0));
        Point center2Pt = allPoints.get(pointsInCluster.get(pointsInCluster.size() - 1));
        
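        // Safeguard: both centers are the same point. With the id tie-break this cannot
        // trigger once the cluster has two or more points.
        if (center1Pt.id == center2Pt.id) {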
             List<Integer> sub1 = new ArrayList<>();
             List<Integer> sub2 = new ArrayList<>();
             for(int i = 0; i < pointIndices.size(); ++i) {
                 if (i < pointIndices.size() / 2) sub1.add(pointIndices.get(i));
                 else sub2.add(pointIndices.get(i));
             }
             result.add(sub1);
             result.add(sub2);
             return result;
        }


        double c1x = center1Pt.x, c1y = center1Pt.y;
        double c2x = center2Pt.x, c2y = center2Pt.y;

        List<Integer> cluster1 = new ArrayList<>();
        List<Integer> cluster2 = new ArrayList<>();

        for (int iter = 0; iter < 1000; iter++) {
            cluster1.clear();
            cluster2.clear();

            for (int idx : pointIndices) {
                Point p = allPoints.get(idx);
                if (euclideanDistSq(p.x, p.y, c1x, c1y) <= euclideanDistSq(p.x, p.y, c2x, c2y)) {
                    cluster1.add(idx);
                } else {
                    cluster2.add(idx);
                }
            }

            double oldC1x = c1x, oldC1y = c1y;
            double oldC2x = c2x, oldC2y = c2y;

            if (!cluster1.isEmpty()) {
                double sumX = 0, sumY = 0;
                for (int idx : cluster1) {
                    sumX += allPoints.get(idx).x;
                    sumY += allPoints.get(idx).y;
                }
                c1x = sumX / cluster1.size();
                c1y = sumY / cluster1.size();
            }
            if (!cluster2.isEmpty()) {
                double sumX = 0, sumY = 0;
                for (int idx : cluster2) {
                    sumX += allPoints.get(idx).x;
                    sumY += allPoints.get(idx).y;
                }
                c2x = sumX / cluster2.size();
                c2y = sumY / cluster2.size();
            }
            
            double moveDist = Math.sqrt(euclideanDistSq(c1x, c1y, oldC1x, oldC1y)) +
                              Math.sqrt(euclideanDistSq(c2x, c2y, oldC2x, oldC2y));

            if (moveDist < 1e-6) {
                break;
            }
        }
        result.add(cluster1);
        result.add(cluster2);
        return result;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int nTarget = sc.nextInt();
        int mPoints = sc.nextInt();
        allPoints = new ArrayList<>();
        for (int i = 0; i < mPoints; i++) {
            allPoints.add(new Point(i, sc.nextDouble(), sc.nextDouble()));
        }

        List<Cluster> clusters = new ArrayList<>();
        List<Integer> initialIndices = new ArrayList<>();
        for (int i = 0; i < mPoints; i++) initialIndices.add(i);
        clusters.add(new Cluster(initialIndices));

        for (int k = 1; k < nTarget; k++) {
            int splitIdx = -1;
            double maxSse = -1.0;

            for (int i = 0; i < clusters.size(); i++) {
                Cluster c = clusters.get(i);
                if (c.pointIndices.size() <= 1) {
                    c.sse = 0;
                    continue;
                }
                double sumX = 0, sumY = 0;
                for (int pIdx : c.pointIndices) {
                    sumX += allPoints.get(pIdx).x;
                    sumY += allPoints.get(pIdx).y;
                }
                c.centroidX = sumX / c.pointIndices.size();
                c.centroidY = sumY / c.pointIndices.size();
                
                double currentSse = 0;
                for (int pIdx : c.pointIndices) {
                    Point p = allPoints.get(pIdx);
                    currentSse += euclideanDistSq(p.x, p.y, c.centroidX, c.centroidY);
                }
                c.sse = currentSse;
                
                if (c.sse > maxSse) {
                    maxSse = c.sse;
                    splitIdx = i;
                }
            }

            List<Integer> toSplitIndices = clusters.get(splitIdx).pointIndices;
            clusters.remove(splitIdx);

            List<List<Integer>> newClustersIndices = kMeansSplit(toSplitIndices);

            if (!newClustersIndices.get(0).isEmpty()) clusters.add(new Cluster(newClustersIndices.get(0)));
            if (!newClustersIndices.get(1).isEmpty()) clusters.add(new Cluster(newClustersIndices.get(1)));
            
            List<Integer> sizes = new ArrayList<>();
            for (Cluster c : clusters) {
                sizes.add(c.pointIndices.size());
            }
            sizes.sort(Collections.reverseOrder());

            for (int i = 0; i < sizes.size(); i++) {
                System.out.print(sizes.get(i) + (i == sizes.size() - 1 ? "" : " "));
            }
            System.out.println();
        }
    }
}

Python implementation:

import math

# A module-level list holds every point as (x, y, input_index) so the sort key can look points up by index
all_points = []

def euclidean_dist_sq(p1, p2):
    return (p1[0] - p2[0])**2 + (p1[1] - p2[1])**2

def k_means_split(point_indices):
    if len(point_indices) <= 1:
        return point_indices, []

    # 1. Initialize the centers: sort by x, then y, then input order
    points_in_cluster = sorted(point_indices, key=lambda i: (all_points[i][0], all_points[i][1], all_points[i][2]))
    
    center1_pt = all_points[points_in_cluster[0]]
    center2_pt = all_points[points_in_cluster[-1]]

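    # Safeguard: both centers are the same point. Because ties are broken by input index,
    # this cannot trigger once the cluster has two or more points, so coincident points
    # are not caught here; they all end up in the first sub-cluster in the loop below.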
    if center1_pt[2] == center2_pt[2]:
        mid = len(point_indices) // 2
        return point_indices[:mid], point_indices[mid:]

    centroid1 = (center1_pt[0], center1_pt[1])
    centroid2 = (center2_pt[0], center2_pt[1])

    for _ in range(1000):
        cluster1_indices, cluster2_indices = [], []
        
        # 2. Assignment: each point goes to the nearer center (ties go to center 1)
        for idx in point_indices:
            p = (all_points[idx][0], all_points[idx][1])
            if euclidean_dist_sq(p, centroid1) <= euclidean_dist_sq(p, centroid2):
                cluster1_indices.append(idx)
            else:
                cluster2_indices.append(idx)

        # 3. Update: remember the old centers, then recompute the mean of each non-empty cluster
        old_centroid1, old_centroid2 = centroid1, centroid2
        
        if cluster1_indices:
            sum_x = sum(all_points[i][0] for i in cluster1_indices)
            sum_y = sum(all_points[i][1] for i in cluster1_indices)
            centroid1 = (sum_x / len(cluster1_indices), sum_y / len(cluster1_indices))

        if cluster2_indices:
            sum_x = sum(all_points[i][0] for i in cluster2_indices)
            sum_y = sum(all_points[i][1] for i in cluster2_indices)
            centroid2 = (sum_x / len(cluster2_indices), sum_y / len(cluster2_indices))

        # 4. Convergence check: total distance moved by both centers
        move_dist = math.sqrt(euclidean_dist_sq(centroid1, old_centroid1)) + math.sqrt(euclidean_dist_sq(centroid2, old_centroid2))
        if move_dist < 1e-6:
            break
            
    return cluster1_indices, cluster2_indices

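def main():
    import sys
    # Read every whitespace-separated token so the line layout of the input does not matter
    tokens = sys.stdin.read().split()
    n_target, m_points = int(tokens[0]), int(tokens[1])
    for i in range(m_points):
        # Coordinates are parsed as floats, matching the C++ and Java versions
        x, y = float(tokens[2 + 2 * i]), float(tokens[3 + 2 * i])
        all_points.append((x, y, i))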

    clusters = [list(range(m_points))]

    for _ in range(n_target - 1):
        # 1. Find the cluster with the largest SSE
        split_idx = -1
        max_sse = -1.0
        
        for i, cluster_indices in enumerate(clusters):
            if len(cluster_indices) <= 1:
                continue

            sum_x = sum(all_points[p_idx][0] for p_idx in cluster_indices)
            sum_y = sum(all_points[p_idx][1] for p_idx in cluster_indices)
            centroid = (sum_x / len(cluster_indices), sum_y / len(cluster_indices))

            current_sse = sum(euclidean_dist_sq((all_points[p_idx][0], all_points[p_idx][1]), centroid) for p_idx in cluster_indices)

            if current_sse > max_sse:
                max_sse = current_sse
                split_idx = i

        # 2. Split it and replace the parent with its non-empty children
        to_split_indices = clusters.pop(split_idx)
        new_cluster1, new_cluster2 = k_means_split(to_split_indices)
        
        if new_cluster1: clusters.append(new_cluster1)
        if new_cluster2: clusters.append(new_cluster2)

        # 3. Output the cluster sizes in descending order
        sizes = sorted([len(c) for c in clusters], reverse=True)
        print(*sizes)

if __name__ == "__main__":
    main()

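Algorithm and Complexity

  • Algorithm: simulation, K-means clustering
  • Time complexity: O(N · M · T), where N is the target number of clusters, M is the total number of points, and T is the maximum number of K-means iterations (1000 in this problem). The outer loop runs N-1 times. Inside the loop, computing the SSE of all clusters takes O(M). The K-means split runs up to T iterations on a cluster of size m, each iteration costing O(m), for a total of O(m · T). In the worst case, m can approach M.
  • Space complexity: O(M), mainly for storing the point coordinates and the cluster assignments.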