题目链接

小红的星屑共鸣

题目描述

小红在二维平面上发现了 颗星屑,每颗星屑都有其坐标 。为了寻找能量最纯净的共鸣源,需要找出距离最近的两颗星屑。为了避免浮点数精度误差,请计算并输出这两颗星屑之间欧几里得距离的平方。

输入:

  • 第一行包含一个整数 ,表示星屑的数量。
  • 接下来的 行,每行包含两个整数 ,表示一颗星屑在平面上的坐标。

数据范围:

输出:

  • 输出一个整数,表示所有点对中最小的距离平方值。

解题思路

这是一个经典的最近点对问题 (Closest Pair of Points Problem)

  1. 算法选择

    • 暴力法检查所有点对的时间复杂度为 ,在 的规模下会超时。
    • 必须使用分治算法
  2. 分治步骤

    • 排序:首先将所有点按 坐标升序排列。
    • 递归:将点集从中线划分为左、右两部分。递归求出左右两部分内部的最小距离平方
    • 合并:在跨越中线的区域( 坐标到中线距离小于 )寻找可能的更小距离。
    • 剪枝:将中线区域的点按 坐标排序,对于每个点,只需检查其后 坐标差异在 范围内的有限个点即可。
  3. 数据类型

    • 坐标最大为 ,距离平方最大可达 ,因此在 C++ 和 Java 中需要使用 long longlong 类型,Python 则自带大数支持。

代码

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>

using namespace std;

typedef long long LL;

struct Point {
    LL x, y;
};

bool compareX(const Point& a, const Point& b) {
    if (a.x != b.x) return a.x < b.x;
    return a.y < b.y;
}

LL distSq(Point p1, Point p2) {
    return (p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y) * (p1.y - p2.y);
}

LL solveRecursive(vector<Point>& pts, int l, int r) {
    if (r - l <= 3) {
        LL d = 8e18; 
        for (int i = l; i < r; ++i) {
            for (int j = i + 1; j <= r; ++j) {
                d = min(d, distSq(pts[i], pts[j]));
            }
        }
        return d;
    }

    int mid = l + (r - l) / 2;
    LL midX = pts[mid].x;
    LL d = min(solveRecursive(pts, l, mid), solveRecursive(pts, mid + 1, r));

    vector<Point> strip;
    for (int i = l; i <= r; ++i) {
        if ((pts[i].x - midX) * (pts[i].x - midX) < d) {
            strip.push_back(pts[i]);
        }
    }

    sort(strip.begin(), strip.end(), [](const Point& a, const Point& b) {
        return a.y < b.y;
    });

    for (int i = 0; i < strip.size(); ++i) {
        for (int j = i + 1; j < strip.size() && (strip[j].y - strip[i].y) * (strip[j].y - strip[i].y) < d; ++j) {
            d = min(d, distSq(strip[i], strip[j]));
        }
    }

    return d;
}

int main() {
    int n;
    cin >> n;
    vector<Point> pts(n);
    for (int i = 0; i < n; ++i) {
        cin >> pts[i].x >> pts[i].y;
    }

    sort(pts.begin(), pts.end(), compareX);
    cout << solveRecursive(pts, 0, n - 1) << endl;

    return 0;
}
import java.util.*;

public class Main {
    static class Point {
        long x, y;
        Point(long x, long y) {
            this.x = x;
            this.y = y;
        }
    }

    static long distSq(Point p1, Point p2) {
        return (p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y) * (p1.y - p2.y);
    }

    static long solve(Point[] pts, int l, int r) {
        if (r - l <= 3) {
            long d = Long.MAX_VALUE;
            for (int i = l; i < r; i++) {
                for (int j = i + 1; j <= r; j++) {
                    d = Math.min(d, distSq(pts[i], pts[j]));
                }
            }
            return d;
        }

        int mid = l + (r - l) / 2;
        long midX = pts[mid].x;
        long d = Math.min(solve(pts, l, mid), solve(pts, mid + 1, r));

        List<Point> strip = new ArrayList<>();
        for (int i = l; i <= r; i++) {
            if ((pts[i].x - midX) * (pts[i].x - midX) < d) {
                strip.add(pts[i]);
            }
        }

        strip.sort(Comparator.comparingLong(p -> p.y));

        for (int i = 0; i < strip.size(); i++) {
            for (int j = i + 1; j < strip.size() && (strip.get(j).y - strip.get(i).y) * (strip.get(j).y - strip.get(i).y) < d; j++) {
                d = Math.min(d, distSq(strip.get(i), strip.get(j)));
            }
        }
        return d;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        Point[] pts = new Point[n];
        for (int i = 0; i < n; i++) {
            pts[i] = new Point(sc.nextLong(), sc.nextLong());
        }

        Arrays.sort(pts, (p1, p2) -> {
            if (p1.x != p2.x) return Long.compare(p1.x, p2.x);
            return Long.compare(p1.y, p2.y);
        });
        System.out.println(solve(pts, 0, n - 1));
    }
}
def dist_sq(p1, p2):
    return (p1[0] - p2[0])**2 + (p1[1] - p2[1])**2

def solve_recursive(pts):
    n = len(pts)
    if n <= 3:
        d = float('inf')
        for i in range(n):
            for j in range(i + 1, n):
                d = min(d, dist_sq(pts[i], pts[j]))
        return d

    mid = n // 2
    mid_x = pts[mid][0]
    dl = solve_recursive(pts[:mid])
    dr = solve_recursive(pts[mid:])
    d = min(dl, dr)

    strip = [p for p in pts if (p[0] - mid_x)**2 < d]
    strip.sort(key=lambda p: p[1])

    for i in range(len(strip)):
        for j in range(i + 1, len(strip)):
            if (strip[j][1] - strip[i][1])**2 >= d:
                break
            d = min(d, dist_sq(strip[i], strip[j]))
    
    return d

def solve():
    n_str = input()
    if not n_str: return
    n = int(n_str)
    pts = []
    for _ in range(n):
        pts.append(tuple(map(int, input().split())))
    
    pts.sort()
    print(solve_recursive(pts))

if __name__ == "__main__":
    solve()

算法及复杂度

  • 算法:分治算法。通过左右划分递归求解,处理中线附近的潜在点对。
  • 时间复杂度:。本实现中合并阶段按 坐标重新排序,若改用归并排序思想可优化至
  • 空间复杂度: