题目链接

【模板】并查集

题目描述

给定 个编号为 的元素,初始时每个元素都属于一个独立的集合。需要实现一个并查集数据结构,支持以下三种操作共 次:

  1. 合并 (Union): 将元素 所在的集合合并。
  2. 查询 (Find): 判断元素 是否在同一个集合中。
  3. 信息查询: 输出元素 所在的集合的大小(包含的元素数量)。

解题思路

并查集(Disjoint Set Union, DSU)是一种用于处理不相交集合的合并与查询问题的数据结构。它通常使用一个数组来模拟一片“森林”,其中每棵树代表一个集合。

核心结构

  • parent 数组:parent[i] 存储元素 i 的父节点。树的根节点的父节点是它自己(即 parent[root] = root)。
  • size 数组:size[i] 仅在 i 是其所在集合的根节点时有意义,表示该集合中的元素总数。

初始化

对于 个元素,我们初始化 个集合。每个元素 i 都是一个独立的集合,因此:

  • parent[i] = i
  • size[i] = 1

关键操作与优化

  1. find (查找根节点)

    • 基本操作:从一个节点开始,不断沿着父节点指针向上移动,直到找到根节点(即 parent[x] == x 的节点)。
    • 路径压缩 (Path Compression):这是一种重要的优化。在 find 操作找到根节点后,我们会再次从起始节点遍历到根节点,将这条路径上所有节点的 parent 指针直接指向根节点。这样可以极大地“压平”树的结构,使得后续对这些节点的 find 操作接近
    function find(i):
        if parent[i] == i:
            return i
        parent[i] = find(parent[i]) // 递归实现路径压缩
        return parent[i]
    
  2. unite (合并集合)

    • 基本操作:给定两个元素 ,首先通过 find 操作找到它们各自的根节点 rootXrootY
    • 如果 rootXrootY 相同,说明它们已经在同一个集合中,无需操作。
    • 否则,将一个根节点作为另一个根节点的子节点,从而合并两棵树。
    • 按大小合并 (Union by Size):这是另一个关键优化。在合并时,我们比较 rootXrootY 所在集合的大小,总是将较小的集合合并到较大的集合上。即,如果 size[rootX] < size[rootY],则将 rootXparent 指向 rootY,并更新 size[rootY] 的值为 size[rootY] + size[rootX]。这可以有效避免树的深度过大,保持树的平衡。

操作实现

  • 操作1 (合并):调用 unite(x, y)
  • 操作2 (查询):判断 find(x) == find(y) 是否成立。
  • 操作3 (信息查询):先调用 find(x) 找到根节点 rootX,然后输出 size[rootX]

通过路径压缩和按大小合并这两种优化,并查集的每次操作的均摊时间复杂度接近常数,具体为 ,其中 是反阿克曼函数,其增长极其缓慢,对于所有实际问题都可以看作一个不超过 5 的小常数。

代码

#include <iostream>
#include <vector>
#include <numeric>

using namespace std;

vector<int> parent;
vector<int> sz;

int find_root(int i) {
    if (parent[i] == i) {
        return i;
    }
    return parent[i] = find_root(parent[i]);
}

void unite_sets(int a, int b) {
    int root_a = find_root(a);
    int root_b = find_root(b);
    if (root_a != root_b) {
        if (sz[root_a] < sz[root_b]) {
            swap(root_a, root_b);
        }
        parent[root_b] = root_a;
        sz[root_a] += sz[root_b];
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;

    parent.resize(n + 1);
    iota(parent.begin(), parent.end(), 0);
    sz.assign(n + 1, 1);

    for (int i = 0; i < m; ++i) {
        int op;
        cin >> op;
        if (op == 1) {
            int u, v;
            cin >> u >> v;
            unite_sets(u, v);
        } else if (op == 2) {
            int u, v;
            cin >> u >> v;
            if (find_root(u) == find_root(v)) {
                cout << "YES\n";
            } else {
                cout << "NO\n";
            }
        } else {
            int u;
            cin >> u;
            cout << sz[find_root(u)] << "\n";
        }
    }

    return 0;
}
import java.util.Scanner;

public class Main {
    private static int[] parent;
    private static int[] sz;

    public static int findRoot(int i) {
        if (parent[i] == i) {
            return i;
        }
        return parent[i] = findRoot(parent[i]);
    }

    public static void uniteSets(int a, int b) {
        int rootA = findRoot(a);
        int rootB = findRoot(b);
        if (rootA != rootB) {
            if (sz[rootA] < sz[rootB]) {
                int temp = rootA;
                rootA = rootB;
                rootB = temp;
            }
            parent[rootB] = rootA;
            sz[rootA] += sz[rootB];
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();

        parent = new int[n + 1];
        sz = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            parent[i] = i;
            sz[i] = 1;
        }

        for (int i = 0; i < m; i++) {
            int op = sc.nextInt();
            if (op == 1) {
                int u = sc.nextInt();
                int v = sc.nextInt();
                uniteSets(u, v);
            } else if (op == 2) {
                int u = sc.nextInt();
                int v = sc.nextInt();
                if (findRoot(u) == findRoot(v)) {
                    System.out.println("YES");
                } else {
                    System.out.println("NO");
                }
            } else {
                int u = sc.nextInt();
                System.out.println(sz[findRoot(u)]);
            }
        }
    }
}
import sys

# 增加递归深度限制
sys.setrecursionlimit(200005)

def find_root(i):
    if parent[i] == i:
        return i
    parent[i] = find_root(parent[i])
    return parent[i]

def unite_sets(a, b):
    root_a = find_root(a)
    root_b = find_root(b)
    if root_a != root_b:
        # 按大小合并
        if sz[root_a] < sz[root_b]:
            root_a, root_b = root_b, root_a
        parent[root_b] = root_a
        sz[root_a] += sz[root_b]

def main():
    n, m = map(int, sys.stdin.readline().split())
    
    global parent, sz
    parent = list(range(n + 1))
    sz = [1] * (n + 1)

    for _ in range(m):
        line = list(map(int, sys.stdin.readline().split()))
        op = line[0]
        
        if op == 1:
            u, v = line[1], line[2]
            unite_sets(u, v)
        elif op == 2:
            u, v = line[1], line[2]
            if find_root(u) == find_root(v):
                sys.stdout.write("YES\n")
            else:
                sys.stdout.write("NO\n")
        else:
            u = line[1]
            sys.stdout.write(str(sz[find_root(u)]) + "\n")

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:带路径压缩和按大小合并的并查集 (Disjoint Set Union)
  • 时间复杂度,其中 是元素数量, 是操作次数, 是反阿克曼函数,其值增长非常缓慢,在实践中可视为一个极小的常数。
  • 空间复杂度,用于存储每个元素的父节点和集合大小。