Almost Union-Find

题目描述

用并查集实现如下的三种操作

1 p q : 合并元素p和q所在的集合如果p,q已经在一个集合内忽略此操作

2 p q:将元素p移动到q集合中,如果两者已经在同一集合忽略次指令

3 p:输出p所在集合的元素个数以及数值总和

样例

5 7
1 1 2
2 3 4
1 3 5
3 4
2 4 1
3 4
3 3
3 12
3 7
2 8

算法1

(用带拓展域的并查集维护可删除并查集)
  • 真正的集合的代表元素不是集合中的某个元素而是一个编号

  • 删除一个并查集中的某个元素就将他的根节点改变即可(这样不会影响其他元素的根节点)

  • 我们可以用带拓展域的并查集来维护前一半1 ~ n表示元素,n + 1 ~2 * n表示集合编号

  • 删除操作

    int pa = find(a),pb = find(b);
    if(pa != pb)
    {
        p[a] = pb;
        sum[pb] += v[a];
        sz[pb] ++;
        sz[pa] --;
        sum[pa] -= v[a];
    }

时间复杂度

参考文献

C++ 代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
// #include <unordered_map>
#include <vector>
#include <queue>
#include <set>
#include <bitset>
#include <cmath>
#include <map>

#define x first
#define y second

#define P 131

#define lc u << 1
#define rc u << 1 | 1

using namespace std;
typedef long long LL;
const int N = 200010;
int p[N];
int sz[N],v[N];
LL sum[N];
int n,q;

int find(int x)
{
    if(p[x] != x) p[x] = find(p[x]);
    return p[x];
}

inline void solve()
{
    while(scanf("%d%d",&n,&q) == 2)
    {
        for(int i = 1;i <= n * 2;i ++)
        {
            if(i <= n) p[i] = i + n,sz[i] = 0,v[i] = i;
            else p[i] = i,sz[i] = 1,v[i] = 0,sum[i] = v[i - n];
        }
        int op;
        while(q -- )
        {
            scanf("%d",&op);
            if(op == 1)
            {
                int a,b;
                scanf("%d%d",&a,&b);
                int pa = find(a),pb = find(b);
                if(pa != pb)
                {
                    p[pa] = pb;
                    sz[pb] += sz[pa];
                    sum[pb] += sum[pa];
                }
            }else if(op == 2)
            {
                int a,b;
                scanf("%d%d",&a,&b);
                int pa = find(a),pb = find(b);
                if(pa != pb)
                {
                    p[a] = pb;
                    sum[pb] += v[a];
                    sz[pb] ++;
                    sz[pa] --;
                    sum[pa] -= v[a];
                }
            }else
            {
                int a;
                scanf("%d",&a);
                int pa = find(a);
                printf("%d %lld\n",sz[pa],sum[pa]);
            }
        }
    }
}

int main()
{
    int _ = 1;
    // freopen("network.in","r",stdin);
    // freopen("network.out","w",stdout);
    // init(N - 1);
    // scanf("%d",&_);
    while(_ --)
    {
        // scanf("%lld%lld",&n,&m);
        solve();
        // test();
    }
    return 0;
}