题解:BISHI102 【模板】并查集
题目链接
题目描述
有编号为 的元素,初始时各自独立。需要支持三类操作:
- 合并操作 1 x y: 合并元素
与
所在集合;
- 查询同集 2 x y: 判断
与
是否同一集合;
- 查询大小 3 x: 输出
所在集合的元素数量。
解题思路
使用并查集 (Disjoint Set Union, DSU) 维护集合:
- 用父节点数组
与按大小合并的
,路径压缩优化
;
- 合并时将较小集合挂到较大集合根上并更新
;
- 查询时比较根是否相同;大小查询返回根的
。
代码
#include <bits/stdc++.h>
using namespace std;
struct DSU {
int n; vector<int> fa, sz;
DSU(int n): n(n), fa(n+1), sz(n+1, 1) { iota(fa.begin(), fa.end(), 0); }
int find(int x){ return fa[x]==x? x : fa[x]=find(fa[x]); }
void unite(int a,int b){ a=find(a); b=find(b); if(a==b) return; if(sz[a]<sz[b]) swap(a,b); fa[b]=a; sz[a]+=sz[b]; }
int size(int x){ return sz[find(x)]; }
};
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n, m; if(!(cin>>n>>m)) return 0;
DSU dsu(n);
string out; out.reserve(m*4);
for(int i=0;i<m;++i){
int op; cin>>op;
if(op==1){ int x,y; cin>>x>>y; dsu.unite(x,y); }
else if(op==2){ int x,y; cin>>x>>y; out += (dsu.find(x)==dsu.find(y)? "YES\n":"NO\n"); }
else { int x; cin>>x; out += to_string(dsu.size(x)); out += '\n'; }
}
cout<<out;
return 0;
}
import java.io.*;
public class Main {
static class FastScanner {
private final InputStream in; private final byte[] buf = new byte[1<<16];
private int p=0,l=0; FastScanner(InputStream is){in=is;}
private int read() throws IOException { if(p>=l){ l=in.read(buf); p=0; if(l<=0) return -1; } return buf[p++]; }
int nextInt() throws IOException { int c; int s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
}
static class DSU {
int[] fa, sz;
DSU(int n){ fa=new int[n+1]; sz=new int[n+1]; for(int i=0;i<=n;i++){ fa[i]=i; sz[i]=1; } }
int find(int x){ return fa[x]==x? x: (fa[x]=find(fa[x])); }
void unite(int a,int b){ a=find(a); b=find(b); if(a==b) return; if(sz[a]<sz[b]){int t=a;a=b;b=t;} fa[b]=a; sz[a]+=sz[b]; }
int size(int x){ return sz[find(x)]; }
}
public static void main(String[] args) throws Exception {
FastScanner fs = new FastScanner(System.in);
int n = fs.nextInt();
int m = fs.nextInt();
DSU dsu = new DSU(n);
StringBuilder out = new StringBuilder();
for(int i=0;i<m;i++){
int op = fs.nextInt();
if(op==1){ int x=fs.nextInt(), y=fs.nextInt(); dsu.unite(x,y); }
else if(op==2){ int x=fs.nextInt(), y=fs.nextInt(); out.append(dsu.find(x)==dsu.find(y)?"YES\n":"NO\n"); }
else { int x=fs.nextInt(); out.append(dsu.size(x)).append('\n'); }
}
System.out.print(out.toString());
}
}
import sys
data = sys.stdin.buffer.read().split()
it = iter(data)
n = int(next(it)); m = int(next(it))
fa = list(range(n+1))
sz = [1]*(n+1)
def find(x: int) -> int:
while fa[x] != x:
fa[x] = fa[fa[x]]
x = fa[x]
return x
out_lines = []
for _ in range(m):
op = int(next(it))
if op == 1:
x = int(next(it)); y = int(next(it))
rx, ry = find(x), find(y)
if rx != ry:
if sz[rx] < sz[ry]:
rx, ry = ry, rx
fa[ry] = rx
sz[rx] += sz[ry]
elif op == 2:
x = int(next(it)); y = int(next(it))
out_lines.append('YES' if find(x) == find(y) else 'NO')
else:
x = int(next(it))
out_lines.append(str(sz[find(x)]))
sys.stdout.write('\n'.join(out_lines))
算法及复杂度
- 算法:并查集(路径压缩 + 按大小合并),维护
与
- 时间复杂度:近似
- 空间复杂度: