题解:BISHI102 【模板】并查集

题目链接

【模板】并查集

题目描述

有编号为 的元素,初始时各自独立。需要支持三类操作:

  • 合并操作 1 x y: 合并元素 所在集合;
  • 查询同集 2 x y: 判断 是否同一集合;
  • 查询大小 3 x: 输出 所在集合的元素数量。

解题思路

使用并查集 (Disjoint Set Union, DSU) 维护集合:

  • 用父节点数组 与按大小合并的 ,路径压缩优化
  • 合并时将较小集合挂到较大集合根上并更新
  • 查询时比较根是否相同;大小查询返回根的

代码

#include <bits/stdc++.h>
using namespace std;

struct DSU {
    int n; vector<int> fa, sz;
    DSU(int n): n(n), fa(n+1), sz(n+1, 1) { iota(fa.begin(), fa.end(), 0); }
    int find(int x){ return fa[x]==x? x : fa[x]=find(fa[x]); }
    void unite(int a,int b){ a=find(a); b=find(b); if(a==b) return; if(sz[a]<sz[b]) swap(a,b); fa[b]=a; sz[a]+=sz[b]; }
    int size(int x){ return sz[find(x)]; }
};

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m; if(!(cin>>n>>m)) return 0;
    DSU dsu(n);
    string out; out.reserve(m*4);
    for(int i=0;i<m;++i){
        int op; cin>>op;
        if(op==1){ int x,y; cin>>x>>y; dsu.unite(x,y); }
        else if(op==2){ int x,y; cin>>x>>y; out += (dsu.find(x)==dsu.find(y)? "YES\n":"NO\n"); }
        else { int x; cin>>x; out += to_string(dsu.size(x)); out += '\n'; }
    }
    cout<<out;
    return 0;
}
import java.io.*;

public class Main {
    static class FastScanner {
        private final InputStream in; private final byte[] buf = new byte[1<<16];
        private int p=0,l=0; FastScanner(InputStream is){in=is;}
        private int read() throws IOException { if(p>=l){ l=in.read(buf); p=0; if(l<=0) return -1; } return buf[p++]; }
        int nextInt() throws IOException { int c; int s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
    }

    static class DSU {
        int[] fa, sz;
        DSU(int n){ fa=new int[n+1]; sz=new int[n+1]; for(int i=0;i<=n;i++){ fa[i]=i; sz[i]=1; } }
        int find(int x){ return fa[x]==x? x: (fa[x]=find(fa[x])); }
        void unite(int a,int b){ a=find(a); b=find(b); if(a==b) return; if(sz[a]<sz[b]){int t=a;a=b;b=t;} fa[b]=a; sz[a]+=sz[b]; }
        int size(int x){ return sz[find(x)]; }
    }

    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int n = fs.nextInt();
        int m = fs.nextInt();
        DSU dsu = new DSU(n);
        StringBuilder out = new StringBuilder();
        for(int i=0;i<m;i++){
            int op = fs.nextInt();
            if(op==1){ int x=fs.nextInt(), y=fs.nextInt(); dsu.unite(x,y); }
            else if(op==2){ int x=fs.nextInt(), y=fs.nextInt(); out.append(dsu.find(x)==dsu.find(y)?"YES\n":"NO\n"); }
            else { int x=fs.nextInt(); out.append(dsu.size(x)).append('\n'); }
        }
        System.out.print(out.toString());
    }
}
import sys

data = sys.stdin.buffer.read().split()
it = iter(data)
n = int(next(it)); m = int(next(it))
fa = list(range(n+1))
sz = [1]*(n+1)

def find(x: int) -> int:
    while fa[x] != x:
        fa[x] = fa[fa[x]]
        x = fa[x]
    return x

out_lines = []
for _ in range(m):
    op = int(next(it))
    if op == 1:
        x = int(next(it)); y = int(next(it))
        rx, ry = find(x), find(y)
        if rx != ry:
            if sz[rx] < sz[ry]:
                rx, ry = ry, rx
            fa[ry] = rx
            sz[rx] += sz[ry]
    elif op == 2:
        x = int(next(it)); y = int(next(it))
        out_lines.append('YES' if find(x) == find(y) else 'NO')
    else:
        x = int(next(it))
        out_lines.append(str(sz[find(x)]))

sys.stdout.write('\n'.join(out_lines))

算法及复杂度

  • 算法:并查集(路径压缩 + 按大小合并),维护
  • 时间复杂度:近似
  • 空间复杂度: