小红的星屑共鸣
题目分析
给定 个二维平面上的点,求所有点对之间欧几里得距离的平方的最小值。
这就是经典的最近点对问题,只不过输出的是距离的平方而非距离本身。
思路
分治法(Closest Pair of Points)
暴力枚举所有点对的时间复杂度为 ,对于
较大的情况无法通过。经典做法是分治:
- 预处理:将所有点按
坐标排序。
- 分治:取中间点的
坐标作为分割线,将点集分为左右两半,分别递归求解左半和右半的最近点对距离
。
- 合并(关键步骤):最近点对可能一个在左半、一个在右半。只需考虑到分割线距离小于
的点(即"strip"区域)。将 strip 中的点按
坐标排序后,对每个点只需检查
坐标差小于
的后续点。数学上可以证明,每个点最多只需检查常数个(不超过 7 个)候选点,因此合并步骤是线性的。
- 归并排序优化:为避免每层都对 strip 重新按
排序,在递归过程中同时完成按
坐标的归并排序(类似归并排序的 merge 步骤),使总复杂度保持
。
> 注意:由于我们比较的是距离的平方,在 strip 筛选时用 代替
,避免浮点运算。
Python 的特殊处理:Python 递归较慢,分治法容易超时。改用随机增量法 + 网格哈希:随机打乱点的顺序,维护边长为 的网格,每插入一个新点只需检查周围
的网格。当最近距离更新时重建网格。期望时间复杂度
。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pll;
ll dist2(const pll& a, const pll& b){
return (a.first-b.first)*(a.first-b.first)+(a.second-b.second)*(a.second-b.second);
}
ll solve(vector<pll>& pts, int l, int r){
if(r - l <= 3){
ll res = LLONG_MAX;
for(int i = l; i < r; i++)
for(int j = i+1; j < r; j++)
res = min(res, dist2(pts[i], pts[j]));
sort(pts.begin()+l, pts.begin()+r, [](const pll& a, const pll& b){ return a.second < b.second; });
return res;
}
int mid = (l+r)/2;
ll midx = pts[mid].first;
ll d = min(solve(pts, l, mid), solve(pts, mid, r));
// merge two halves sorted by y
vector<pll> tmp(r-l);
int i = l, j = mid, k = 0;
while(i < mid && j < r){
if(pts[i].second <= pts[j].second) tmp[k++] = pts[i++];
else tmp[k++] = pts[j++];
}
while(i < mid) tmp[k++] = pts[i++];
while(j < r) tmp[k++] = pts[j++];
copy(tmp.begin(), tmp.end(), pts.begin()+l);
// strip
vector<pll> strip;
for(int i = l; i < r; i++){
ll dx = pts[i].first - midx;
if(dx * dx < d)
strip.push_back(pts[i]);
}
for(int i = 0; i < (int)strip.size(); i++){
for(int j = i+1; j < (int)strip.size(); j++){
ll dy = strip[j].second - strip[i].second;
if(dy * dy >= d) break;
d = min(d, dist2(strip[i], strip[j]));
}
}
return d;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin >> n;
vector<pll> pts(n);
for(int i = 0; i < n; i++) cin >> pts[i].first >> pts[i].second;
sort(pts.begin(), pts.end());
cout << solve(pts, 0, n) << "\n";
return 0;
}
import java.util.*;
import java.io.*;
public class Main {
static long dist2(long[] a, long[] b) {
return (a[0]-b[0])*(a[0]-b[0]) + (a[1]-b[1])*(a[1]-b[1]);
}
static long solve(long[][] pts, int l, int r) {
if (r - l <= 3) {
long res = Long.MAX_VALUE;
for (int i = l; i < r; i++)
for (int j = i+1; j < r; j++)
res = Math.min(res, dist2(pts[i], pts[j]));
Arrays.sort(pts, l, r, (a, b) -> Long.compare(a[1], b[1]));
return res;
}
int mid = (l + r) / 2;
long midx = pts[mid][0];
long d = Math.min(solve(pts, l, mid), solve(pts, mid, r));
long[][] tmp = new long[r - l][2];
int i = l, j = mid, k = 0;
while (i < mid && j < r) {
if (pts[i][1] <= pts[j][1]) tmp[k++] = pts[i++];
else tmp[k++] = pts[j++];
}
while (i < mid) tmp[k++] = pts[i++];
while (j < r) tmp[k++] = pts[j++];
System.arraycopy(tmp, 0, pts, l, r - l);
List<long[]> strip = new ArrayList<>();
for (int p = l; p < r; p++) {
long dx = pts[p][0] - midx;
if (dx * dx < d) strip.add(pts[p]);
}
for (int p = 0; p < strip.size(); p++) {
for (int q = p + 1; q < strip.size(); q++) {
long dy = strip.get(q)[1] - strip.get(p)[1];
if (dy * dy >= d) break;
d = Math.min(d, dist2(strip.get(p), strip.get(q)));
}
}
return d;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine().trim());
long[][] pts = new long[n][2];
for (int i = 0; i < n; i++) {
StringTokenizer st = new StringTokenizer(br.readLine());
pts[i][0] = Long.parseLong(st.nextToken());
pts[i][1] = Long.parseLong(st.nextToken());
}
Arrays.sort(pts, (a, b) -> Long.compare(a[0], b[0]));
System.out.println(solve(pts, 0, n));
}
}
import sys
from random import shuffle
def solve():
input_data = sys.stdin.buffer.read().split()
idx = 0
n = int(input_data[idx]); idx += 1
pts = []
for i in range(n):
x = int(input_data[idx]); idx += 1
y = int(input_data[idx]); idx += 1
pts.append((x, y))
if n <= 1:
print(0)
return
shuffle(pts)
def dist2(a, b):
return (a[0]-b[0])**2 + (a[1]-b[1])**2
d = dist2(pts[0], pts[1])
if d == 0:
print(0)
return
from math import isqrt
def make_grid(pts_list, sz):
g = {}
for p in pts_list:
gx = p[0] // sz
gy = p[1] // sz
if (gx, gy) not in g:
g[(gx, gy)] = []
g[(gx, gy)].append(p)
return g
sz = max(1, isqrt(d))
grid = make_grid(pts[:2], sz)
for i in range(2, n):
p = pts[i]
gx = p[0] // sz
gy = p[1] // sz
mind = d
for dx in range(-2, 3):
for dy in range(-2, 3):
cell = grid.get((gx+dx, gy+dy))
if cell:
for q in cell:
dd = dist2(p, q)
if dd < mind:
mind = dd
if mind < d:
d = mind
if d == 0:
print(0)
return
sz = max(1, isqrt(d))
grid = make_grid(pts[:i+1], sz)
else:
if (gx, gy) not in grid:
grid[(gx, gy)] = []
grid[(gx, gy)].append(p)
print(d)
solve()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', l => lines.push(l));
rl.on('close', () => {
const n = parseInt(lines[0]);
const pts = [];
for (let i = 1; i <= n; i++) {
const [x, y] = lines[i].split(' ').map(Number);
pts.push([x, y]);
}
pts.sort((a, b) => a[0] - b[0] || a[1] - b[1]);
function dist2(a, b) {
return (a[0]-b[0])*(a[0]-b[0]) + (a[1]-b[1])*(a[1]-b[1]);
}
function solve(l, r) {
if (r - l <= 3) {
let res = Infinity;
for (let i = l; i < r; i++)
for (let j = i+1; j < r; j++)
res = Math.min(res, dist2(pts[i], pts[j]));
const sub = pts.slice(l, r).sort((a, b) => a[1] - b[1]);
for (let i = l; i < r; i++) pts[i] = sub[i - l];
return res;
}
const mid = (l + r) >> 1;
const midx = pts[mid][0];
let d = Math.min(solve(l, mid), solve(mid, r));
const tmp = [];
let i = l, j = mid;
while (i < mid && j < r) {
if (pts[i][1] <= pts[j][1]) tmp.push(pts[i++]);
else tmp.push(pts[j++]);
}
while (i < mid) tmp.push(pts[i++]);
while (j < r) tmp.push(pts[j++]);
for (let k = 0; k < tmp.length; k++) pts[l + k] = tmp[k];
const strip = [];
for (let p = l; p < r; p++) {
const dx = pts[p][0] - midx;
if (dx * dx < d) strip.push(pts[p]);
}
for (let p = 0; p < strip.length; p++) {
for (let q = p + 1; q < strip.length; q++) {
const dy = strip[q][1] - strip[p][1];
if (dy * dy >= d) break;
d = Math.min(d, dist2(strip[p], strip[q]));
}
}
return d;
}
console.log(solve(0, n));
});
复杂度分析
- 时间复杂度:
(分治法);Python 随机增量法期望
。
- 空间复杂度:
,用于存储点集、归并临时数组和 strip 数组。

京公网安备 11010502036488号