小红的星屑共鸣

题目分析

给定 个二维平面上的点,求所有点对之间欧几里得距离的平方的最小值。

这就是经典的最近点对问题,只不过输出的是距离的平方而非距离本身。

思路

分治法(Closest Pair of Points)

暴力枚举所有点对的时间复杂度为 ,对于 较大的情况无法通过。经典做法是分治

  1. 预处理:将所有点按 坐标排序。
  1. 分治:取中间点的 坐标作为分割线,将点集分为左右两半,分别递归求解左半和右半的最近点对距离
  1. 合并(关键步骤):最近点对可能一个在左半、一个在右半。只需考虑到分割线距离小于 的点(即"strip"区域)。将 strip 中的点按 坐标排序后,对每个点只需检查 坐标差小于 的后续点。数学上可以证明,每个点最多只需检查常数个(不超过 7 个)候选点,因此合并步骤是线性的。
  1. 归并排序优化:为避免每层都对 strip 重新按 排序,在递归过程中同时完成按 坐标的归并排序(类似归并排序的 merge 步骤),使总复杂度保持

> 注意:由于我们比较的是距离的平方,在 strip 筛选时用 代替 ,避免浮点运算。

Python 的特殊处理:Python 递归较慢,分治法容易超时。改用随机增量法 + 网格哈希:随机打乱点的顺序,维护边长为 的网格,每插入一个新点只需检查周围 的网格。当最近距离更新时重建网格。期望时间复杂度

代码

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef pair<ll,ll> pll;

ll dist2(const pll& a, const pll& b){
    return (a.first-b.first)*(a.first-b.first)+(a.second-b.second)*(a.second-b.second);
}

ll solve(vector<pll>& pts, int l, int r){
    if(r - l <= 3){
        ll res = LLONG_MAX;
        for(int i = l; i < r; i++)
            for(int j = i+1; j < r; j++)
                res = min(res, dist2(pts[i], pts[j]));
        sort(pts.begin()+l, pts.begin()+r, [](const pll& a, const pll& b){ return a.second < b.second; });
        return res;
    }
    int mid = (l+r)/2;
    ll midx = pts[mid].first;
    ll d = min(solve(pts, l, mid), solve(pts, mid, r));

    // merge two halves sorted by y
    vector<pll> tmp(r-l);
    int i = l, j = mid, k = 0;
    while(i < mid && j < r){
        if(pts[i].second <= pts[j].second) tmp[k++] = pts[i++];
        else tmp[k++] = pts[j++];
    }
    while(i < mid) tmp[k++] = pts[i++];
    while(j < r) tmp[k++] = pts[j++];
    copy(tmp.begin(), tmp.end(), pts.begin()+l);

    // strip
    vector<pll> strip;
    for(int i = l; i < r; i++){
        ll dx = pts[i].first - midx;
        if(dx * dx < d)
            strip.push_back(pts[i]);
    }

    for(int i = 0; i < (int)strip.size(); i++){
        for(int j = i+1; j < (int)strip.size(); j++){
            ll dy = strip[j].second - strip[i].second;
            if(dy * dy >= d) break;
            d = min(d, dist2(strip[i], strip[j]));
        }
    }
    return d;
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;
    vector<pll> pts(n);
    for(int i = 0; i < n; i++) cin >> pts[i].first >> pts[i].second;

    sort(pts.begin(), pts.end());

    cout << solve(pts, 0, n) << "\n";
    return 0;
}
import java.util.*;
import java.io.*;

public class Main {
    static long dist2(long[] a, long[] b) {
        return (a[0]-b[0])*(a[0]-b[0]) + (a[1]-b[1])*(a[1]-b[1]);
    }

    static long solve(long[][] pts, int l, int r) {
        if (r - l <= 3) {
            long res = Long.MAX_VALUE;
            for (int i = l; i < r; i++)
                for (int j = i+1; j < r; j++)
                    res = Math.min(res, dist2(pts[i], pts[j]));
            Arrays.sort(pts, l, r, (a, b) -> Long.compare(a[1], b[1]));
            return res;
        }
        int mid = (l + r) / 2;
        long midx = pts[mid][0];
        long d = Math.min(solve(pts, l, mid), solve(pts, mid, r));

        long[][] tmp = new long[r - l][2];
        int i = l, j = mid, k = 0;
        while (i < mid && j < r) {
            if (pts[i][1] <= pts[j][1]) tmp[k++] = pts[i++];
            else tmp[k++] = pts[j++];
        }
        while (i < mid) tmp[k++] = pts[i++];
        while (j < r) tmp[k++] = pts[j++];
        System.arraycopy(tmp, 0, pts, l, r - l);

        List<long[]> strip = new ArrayList<>();
        for (int p = l; p < r; p++) {
            long dx = pts[p][0] - midx;
            if (dx * dx < d) strip.add(pts[p]);
        }

        for (int p = 0; p < strip.size(); p++) {
            for (int q = p + 1; q < strip.size(); q++) {
                long dy = strip.get(q)[1] - strip.get(p)[1];
                if (dy * dy >= d) break;
                d = Math.min(d, dist2(strip.get(p), strip.get(q)));
            }
        }
        return d;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine().trim());
        long[][] pts = new long[n][2];
        for (int i = 0; i < n; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            pts[i][0] = Long.parseLong(st.nextToken());
            pts[i][1] = Long.parseLong(st.nextToken());
        }
        Arrays.sort(pts, (a, b) -> Long.compare(a[0], b[0]));
        System.out.println(solve(pts, 0, n));
    }
}
import sys
from random import shuffle

def solve():
    input_data = sys.stdin.buffer.read().split()
    idx = 0
    n = int(input_data[idx]); idx += 1
    pts = []
    for i in range(n):
        x = int(input_data[idx]); idx += 1
        y = int(input_data[idx]); idx += 1
        pts.append((x, y))

    if n <= 1:
        print(0)
        return

    shuffle(pts)

    def dist2(a, b):
        return (a[0]-b[0])**2 + (a[1]-b[1])**2

    d = dist2(pts[0], pts[1])
    if d == 0:
        print(0)
        return

    from math import isqrt

    def make_grid(pts_list, sz):
        g = {}
        for p in pts_list:
            gx = p[0] // sz
            gy = p[1] // sz
            if (gx, gy) not in g:
                g[(gx, gy)] = []
            g[(gx, gy)].append(p)
        return g

    sz = max(1, isqrt(d))
    grid = make_grid(pts[:2], sz)

    for i in range(2, n):
        p = pts[i]
        gx = p[0] // sz
        gy = p[1] // sz
        mind = d
        for dx in range(-2, 3):
            for dy in range(-2, 3):
                cell = grid.get((gx+dx, gy+dy))
                if cell:
                    for q in cell:
                        dd = dist2(p, q)
                        if dd < mind:
                            mind = dd
        if mind < d:
            d = mind
            if d == 0:
                print(0)
                return
            sz = max(1, isqrt(d))
            grid = make_grid(pts[:i+1], sz)
        else:
            if (gx, gy) not in grid:
                grid[(gx, gy)] = []
            grid[(gx, gy)].append(p)

    print(d)

solve()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', l => lines.push(l));
rl.on('close', () => {
    const n = parseInt(lines[0]);
    const pts = [];
    for (let i = 1; i <= n; i++) {
        const [x, y] = lines[i].split(' ').map(Number);
        pts.push([x, y]);
    }
    pts.sort((a, b) => a[0] - b[0] || a[1] - b[1]);

    function dist2(a, b) {
        return (a[0]-b[0])*(a[0]-b[0]) + (a[1]-b[1])*(a[1]-b[1]);
    }

    function solve(l, r) {
        if (r - l <= 3) {
            let res = Infinity;
            for (let i = l; i < r; i++)
                for (let j = i+1; j < r; j++)
                    res = Math.min(res, dist2(pts[i], pts[j]));
            const sub = pts.slice(l, r).sort((a, b) => a[1] - b[1]);
            for (let i = l; i < r; i++) pts[i] = sub[i - l];
            return res;
        }
        const mid = (l + r) >> 1;
        const midx = pts[mid][0];
        let d = Math.min(solve(l, mid), solve(mid, r));

        const tmp = [];
        let i = l, j = mid;
        while (i < mid && j < r) {
            if (pts[i][1] <= pts[j][1]) tmp.push(pts[i++]);
            else tmp.push(pts[j++]);
        }
        while (i < mid) tmp.push(pts[i++]);
        while (j < r) tmp.push(pts[j++]);
        for (let k = 0; k < tmp.length; k++) pts[l + k] = tmp[k];

        const strip = [];
        for (let p = l; p < r; p++) {
            const dx = pts[p][0] - midx;
            if (dx * dx < d) strip.push(pts[p]);
        }
        for (let p = 0; p < strip.length; p++) {
            for (let q = p + 1; q < strip.length; q++) {
                const dy = strip[q][1] - strip[p][1];
                if (dy * dy >= d) break;
                d = Math.min(d, dist2(strip[p], strip[q]));
            }
        }
        return d;
    }

    console.log(solve(0, n));
});

复杂度分析

  • 时间复杂度(分治法);Python 随机增量法期望
  • 空间复杂度,用于存储点集、归并临时数组和 strip 数组。