题目链接
题目描述
给定 个编号为
的元素,初始时每个元素都属于一个独立的集合。需要实现一个并查集数据结构,支持以下三种操作共
次:
- 合并 (Union): 将元素
和
所在的集合合并。
- 查询 (Find): 判断元素
和
是否在同一个集合中。
- 信息查询: 输出元素
所在的集合的大小(包含的元素数量)。
解题思路
并查集(Disjoint Set Union, DSU)是一种用于处理不相交集合的合并与查询问题的数据结构。它通常使用一个数组来模拟一片“森林”,其中每棵树代表一个集合。
核心结构
parent
数组:parent[i]
存储元素i
的父节点。树的根节点的父节点是它自己(即parent[root] = root
)。size
数组:size[i]
仅在i
是其所在集合的根节点时有意义,表示该集合中的元素总数。
初始化
对于 个元素,我们初始化
个集合。每个元素
i
都是一个独立的集合,因此:
parent[i] = i
size[i] = 1
关键操作与优化
-
find
(查找根节点)- 基本操作:从一个节点开始,不断沿着父节点指针向上移动,直到找到根节点(即
parent[x] == x
的节点)。 - 路径压缩 (Path Compression):这是一种重要的优化。在
find
操作找到根节点后,我们会再次从起始节点遍历到根节点,将这条路径上所有节点的parent
指针直接指向根节点。这样可以极大地“压平”树的结构,使得后续对这些节点的find
操作接近。
function find(i): if parent[i] == i: return i parent[i] = find(parent[i]) // 递归实现路径压缩 return parent[i]
- 基本操作:从一个节点开始,不断沿着父节点指针向上移动,直到找到根节点(即
-
unite
(合并集合)- 基本操作:给定两个元素
和
,首先通过
find
操作找到它们各自的根节点rootX
和rootY
。 - 如果
rootX
和rootY
相同,说明它们已经在同一个集合中,无需操作。 - 否则,将一个根节点作为另一个根节点的子节点,从而合并两棵树。
- 按大小合并 (Union by Size):这是另一个关键优化。在合并时,我们比较
rootX
和rootY
所在集合的大小,总是将较小的集合合并到较大的集合上。即,如果size[rootX] < size[rootY]
,则将rootX
的parent
指向rootY
,并更新size[rootY]
的值为size[rootY] + size[rootX]
。这可以有效避免树的深度过大,保持树的平衡。
- 基本操作:给定两个元素
操作实现
- 操作1 (合并):调用
unite(x, y)
。 - 操作2 (查询):判断
find(x) == find(y)
是否成立。 - 操作3 (信息查询):先调用
find(x)
找到根节点rootX
,然后输出size[rootX]
。
通过路径压缩和按大小合并这两种优化,并查集的每次操作的均摊时间复杂度接近常数,具体为 ,其中
是反阿克曼函数,其增长极其缓慢,对于所有实际问题都可以看作一个不超过 5 的小常数。
代码
#include <iostream>
#include <vector>
#include <numeric>
using namespace std;
vector<int> parent;
vector<int> sz;
int find_root(int i) {
if (parent[i] == i) {
return i;
}
return parent[i] = find_root(parent[i]);
}
void unite_sets(int a, int b) {
int root_a = find_root(a);
int root_b = find_root(b);
if (root_a != root_b) {
if (sz[root_a] < sz[root_b]) {
swap(root_a, root_b);
}
parent[root_b] = root_a;
sz[root_a] += sz[root_b];
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, m;
cin >> n >> m;
parent.resize(n + 1);
iota(parent.begin(), parent.end(), 0);
sz.assign(n + 1, 1);
for (int i = 0; i < m; ++i) {
int op;
cin >> op;
if (op == 1) {
int u, v;
cin >> u >> v;
unite_sets(u, v);
} else if (op == 2) {
int u, v;
cin >> u >> v;
if (find_root(u) == find_root(v)) {
cout << "YES\n";
} else {
cout << "NO\n";
}
} else {
int u;
cin >> u;
cout << sz[find_root(u)] << "\n";
}
}
return 0;
}
import java.util.Scanner;
public class Main {
private static int[] parent;
private static int[] sz;
public static int findRoot(int i) {
if (parent[i] == i) {
return i;
}
return parent[i] = findRoot(parent[i]);
}
public static void uniteSets(int a, int b) {
int rootA = findRoot(a);
int rootB = findRoot(b);
if (rootA != rootB) {
if (sz[rootA] < sz[rootB]) {
int temp = rootA;
rootA = rootB;
rootB = temp;
}
parent[rootB] = rootA;
sz[rootA] += sz[rootB];
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
parent = new int[n + 1];
sz = new int[n + 1];
for (int i = 1; i <= n; i++) {
parent[i] = i;
sz[i] = 1;
}
for (int i = 0; i < m; i++) {
int op = sc.nextInt();
if (op == 1) {
int u = sc.nextInt();
int v = sc.nextInt();
uniteSets(u, v);
} else if (op == 2) {
int u = sc.nextInt();
int v = sc.nextInt();
if (findRoot(u) == findRoot(v)) {
System.out.println("YES");
} else {
System.out.println("NO");
}
} else {
int u = sc.nextInt();
System.out.println(sz[findRoot(u)]);
}
}
}
}
import sys
# 增加递归深度限制
sys.setrecursionlimit(200005)
def find_root(i):
if parent[i] == i:
return i
parent[i] = find_root(parent[i])
return parent[i]
def unite_sets(a, b):
root_a = find_root(a)
root_b = find_root(b)
if root_a != root_b:
# 按大小合并
if sz[root_a] < sz[root_b]:
root_a, root_b = root_b, root_a
parent[root_b] = root_a
sz[root_a] += sz[root_b]
def main():
n, m = map(int, sys.stdin.readline().split())
global parent, sz
parent = list(range(n + 1))
sz = [1] * (n + 1)
for _ in range(m):
line = list(map(int, sys.stdin.readline().split()))
op = line[0]
if op == 1:
u, v = line[1], line[2]
unite_sets(u, v)
elif op == 2:
u, v = line[1], line[2]
if find_root(u) == find_root(v):
sys.stdout.write("YES\n")
else:
sys.stdout.write("NO\n")
else:
u = line[1]
sys.stdout.write(str(sz[find_root(u)]) + "\n")
if __name__ == "__main__":
main()
算法及复杂度
- 算法:带路径压缩和按大小合并的并查集 (Disjoint Set Union)
- 时间复杂度:
,其中
是元素数量,
是操作次数,
是反阿克曼函数,其值增长非常缓慢,在实践中可视为一个极小的常数。
- 空间复杂度:
,用于存储每个元素的父节点和集合大小。