并查集的主要用处是解决一类连通性问题的,遇到此类雷同问题,我们通常可以用并查集维护,往往可能在一些比较简单的问题中比dfs和bfs有更好写的优势。
并查集的代码很短,主要我们是判断连通性和uni,就是我们把两个连通块合并的操作,接下来我们会引出几个具体的函数。
具体呢,我们先讲并查集维护的思想
首先我们画个图举个例子
这大概是并查集当前的结构,我们每个点指向一个父亲节点,然后我们不断跳fa就可以知道这个连通块所在的根节点。
这样我们可以做到快速合并两个连通块,直接fa[ru]=rv就可以了,这个是O(1)的
接着我们发现这样如果是我们要查询一个连通快的根节点是非常慢的,最差的情况是这样的。
假设这是一个很长的链,我们对于较底部的节点就会产生非常大的时间复杂度,因此我们考虑如何优化,一种比较直观比较套路的思路是我们启发式合并,这样复杂度肯定是对的,因此我们就有了一个log的按秩合并。
但是我们有没有复杂度更优而且更好些更便捷的方式呢?
我们在引入一个路径压缩的思路
也就是我们初始是一条链,我们每次在访问一个节点的时候我们递归把一路上的节点都连到Root上,
再画张图解释一下。
这样可以大幅度减少我们的复杂度,然后大概是O(n)+一个小常数,我们可以把这个常数忽略不计。
我们给一下代码:
struct Dsu{ int fa[N]; void init(int n){for(int i=1;i<=n;i++)fa[i]=i;} int find(int u){return fa[u]==u?u:fa[u]=find(fa[u]);} void uni(int u,int v){if(find(u)!=find(v))fa[fa[u]]=fa[v];} }t;
还有一个加上按秩合并的代码:
struct Dsu{ int fa[N],siz[N]; void init(int n){for(int i=1;i<=n;i++)fa[i]=i,siz[i]=1;} int find(int u){return fa[u]==u?u:fa[u]=find(fa[u]);} void uni(int u,int v){ int ru=find(u),rv=find(v); if(ru==rv)return; if(siz[ru]<siz[rv])swap(ru,rv); siz[ru]+=siz[rv];fa[rv]=ru; } }t;
(其实按秩合并在这里的用处不大,因为一直有了路径压缩,所以我们只需要用最上面的写法就可以,而且简便)
下面我们来看看并查集的几道简单题
洛谷P1111 修复公路
https://www.luogu.com.cn/problem/P1111
用并查集维护连通性,然后跑一个最小生成树
代码:
#include<bits/stdc++.h> #define LL long long using namespace std; const int N=1e5+5; struct graph{int u,v,t;}gra[N]; int n,m,fa[N],cnt; bool cmp(graph a,graph b){return a.t<b.t;} int find(int u){return fa[u]==u?u:fa[u]=find(fa[u]);} int main(){ scanf("%d%d",&n,&m); for(int i=1;i<=m;i++)scanf("%d%d%d",&gra[i].u,&gra[i].v,&gra[i].t); sort(gra+1,gra+m+1,cmp); for(int i=1;i<=n;i++)fa[i]=i; for(int i=1,u,v;i<=m;i++){ u=gra[i].u;v=gra[i].v; if(find(u)!=find(v)){ fa[fa[u]]=fa[v];cnt++; if(cnt==n-1){printf("%d\n",gra[i].t);return 0;} } } puts("-1"); return 0; }
洛谷P3958 奶酪
https://www.luogu.com.cn/problem/P3958
其实这道题其他的方法也可以做,诸如bfs,dfs等,但是并查集可能是最方便的,不用建图,直接判断就可以了,是个平方做法。
代码:
#include<bits/stdc++.h> #define LL long long #define db double using namespace std; const int N=1e3+5; struct node{int x,y,z;}a[N]; int n,h,r,s,t,fa[N]; int find(int u){return fa[u]==u?u:fa[u]=find(fa[u]);} void uni(int u,int v){if(find(u)!=find(v))fa[fa[u]]=fa[v];} db sqr(int x){return (db)x*x;} int main(){ int T;scanf("%d",&T); while(T--){ scanf("%d%d%d",&n,&h,&r);s=n+1;t=n+2; for(int i=1;i<=t;i++)fa[i]=i; for(int i=1;i<=n;i++)scanf("%d%d%d",&a[i].x,&a[i].y,&a[i].z); for(int i=1;i<=n;i++){ if(a[i].z<=r)uni(s,i); if(a[i].z>=h-r)uni(i,t); } for(int i=1;i<=n;i++)for(int j=i+1;j<=n;j++){ db d=sqrt(sqr(a[i].x-a[j].x)+sqr(a[i].y-a[j].y)+sqr(a[i].z-a[j].z)); if(d<=r*2.0)uni(i,j); } puts(find(s)==find(t)?"Yes":"No"); } return 0; }
洛谷P4185 [USACO18JAN]MooTube G
https://www.luogu.com.cn/problem/P4185
这题是一个经典的套路,我们把建边和查询放在一起,然后struct里面sort
然后动态维护siz,我们就可以离线下来算出答案
代码:
#include<bits/stdc++.h> #define LL long long using namespace std; const int N=2e5+5; struct node{int u,v,w;}e[N]; int n,q,cnt,fa[N],sz[N],ans[N]; int find(int u){return fa[u]==u?u:fa[u]=find(fa[u]);} int main(){ scanf("%d%d",&n,&q); for(int i=1;i<n;i++)scanf("%d%d%d",&e[i].u,&e[i].v,&e[i].w); cnt=n-1;for(int i=1;i<=n;i++)fa[i]=i,sz[i]=1; for(int i=1,u,w;i<=q;i++)scanf("%d%d",&w,&u),e[++cnt]=(node){u,-i,w}; sort(e+1,e+cnt+1,[&](node a,node b){return (a.w!=b.w?a.w>b.w:a.v>b.v);}); //for(int i=1;i<=cnt;i++)printf("%d %d %d\n",e[i].u,e[i].v,e[i].w); for(int i=1;i<=cnt;i++)if(e[i].v>0){ int u=e[i].u,v=e[i].v,ru=find(u),rv=find(v); if(ru!=rv)sz[rv]+=sz[ru],fa[ru]=rv; }else ans[-e[i].v]=sz[find(e[i].u)]-1; for(int i=1;i<=q;i++)printf("%d\n",ans[i]); return 0; }
洛谷P1197 [JSOI2008]星球大战
https://www.luogu.com.cn/problem/P1197
还是一个很常见的套路,对于删边,删点的问题,我们如果直接连通性可能要用LCT这样的数据结构维护,但是往往会码量翻倍。
我们可以使用时间倒流的方法维护,是一种离线算法,每次把所有的询问翻转,然后倒着维护,这样就变成了加点加边的操作,我们可以快速用并查集操作!
代码:
#include<bits/stdc++.h> #define LL long long #define pb push_back using namespace std; const int N=4e5+5; struct Grpah{int u,v;}g[N]; int n,m,k,fa[N],tot,atk[N];bool usd[N]; vector<int>adj[N],Ans; int find(int u){return fa[u]==u?u:fa[u]=find(fa[u]);} void uni(int u,int v){if(find(u)!=find(v))fa[fa[u]]=fa[v],tot--;} int main(){ scanf("%d%d",&n,&m); for(int i=1,u,v;i<=m;i++){ scanf("%d%d",&u,&v);u++;v++; adj[u].pb(v);adj[v].pb(u); g[i]=(Grpah){u,v}; } scanf("%d",&k); for(int i=1;i<=k;i++)scanf("%d",&atk[i]),usd[++atk[i]]=true; tot=n-k;for(int i=1;i<=n;i++)fa[i]=i; for(int i=1;i<=m;i++)if(!usd[g[i].u]&&!usd[g[i].v])uni(g[i].u,g[i].v); Ans.pb(tot); for(int i=k;i;i--){ usd[atk[i]]=false;tot++; for(auto v:adj[atk[i]])if(!usd[v])uni(atk[i],v); Ans.pb(tot); } reverse(Ans.begin(),Ans.end()); for(auto x:Ans)printf("%d\n",x); return 0; }
洛谷P1955 [NOI2015]程序自动分析
https://www.luogu.com.cn/problem/P1955
这题对于等号不等号之间的处理我们或许可以用带权并查集
但是本题我们可以观察到只有两种操作,我们可以事先把所有等号的连接,然后之后所有的不等我们不能在同一个连通快,如果满足这样的关系就是T,反之F
代码:
#include<bits/stdc++.h> #define LL long long using namespace std; const int N=2e6+5; struct Graph{int u,v,c;}g[N]; int fa[N],id;map<int,int>mp; int find(int u){return fa[u]==u?u:fa[u]=find(fa[u]);} int main(){ int T;scanf("%d",&T); while(T--){ int n;scanf("%d",&n); mp.clear();id=0; for(int i=1,u,v,c;i<=n;i++){ scanf("%d%d%d",&u,&v,&c); if(!mp[u])mp[u]=++id; if(!mp[v])mp[v]=++id; u=mp[u];v=mp[v]; g[i]=(Graph){u,v,c}; } for(int i=1;i<=id;i++)fa[i]=i; for(int i=1,u,v,c;i<=n;i++){ u=g[i].u;v=g[i].v;c=g[i].c; if(c==1){if(find(u)!=find(v))fa[fa[u]]=fa[v];} } bool fl=true; for(int i=1,u,v,c;i<=n;i++){ u=g[i].u;v=g[i].v;c=g[i].c; if(!c){if(find(u)==find(v)){fl=false;break;}} } puts(fl?"YES":"NO"); } return 0; }
接下来简单介绍一下带权并查集的思路和核心
我们和一般并查集的结构是一样的,图里是一个连通快,然后只是我们每个点指向它的fa有一个边权,然后我们带权只需要快速求出到root的边权,就能处理两两之间的关系,这种关系通常是可加的,例如权值加法,异或等等。
洛谷P2024 [NOI2001]食物链
https://www.luogu.com.cn/problem/P2024
我们考虑题目中给出的食物链结构肯定是一个三元环,我们可以考虑用带权并查集,维护一个对3取余的模数,然后判断彼此的关系就可以了。
然而这道题仅仅是一个三元环,可以提示我们其实不需要怎么做,可以不用使用带权并查集,因为环上的点仅为3,我们可以发现任意两个点总是有关系的,类似于石头剪刀布,你们我们确定了两个点的关系以后,把它们的前驱后继的关系也确定就可以了,这里我们对于拆3个点,然后连边。
代码:
#include<bits/stdc++.h> #define LL long long using namespace std; const int N=5e5+5; int n,q,fa[N],ans; int find(int u){return fa[u]==u?u:fa[u]=find(fa[u]);} void uni(int u,int v){if(find(u)!=find(v))fa[fa[u]]=fa[v];} int main(){ scanf("%d%d",&n,&q); for(int i=1;i<=n*3;i++)fa[i]=i; while(q--){ int opt,x,y;scanf("%d%d%d",&opt,&x,&y); if(x>n||y>n){ans++;continue;} if(opt==1){ if(find(x+n)==find(y)||find(x+(n<<1))==find(y))ans++; else uni(x,y),uni(x+n,y+n),uni(x+(n<<1),y+(n<<1)); }else{ if(x==y||find(x)==find(y)||find(x+(n<<1))==find(y))ans++; else uni(x,y+(n<<1)),uni(x+n,y),uni(x+(n<<1),y+n); } } printf("%d\n",ans); return 0; }
洛谷P1196 [NOI2002]银河英雄传说
https://www.luogu.com.cn/problem/P1196
这题是一个带权并查集的裸题,我们通过这道题来介绍一下带权并查集的具体写法,
这题我们对于每个点要维护一下它和它在并查集上的父亲节点的距离,我们front数组,那么连接的时候很简单,我们只需要每次都维护一个siz,然后连接的时候front加一下siz就可以
既然有了siz,其实我们已经得出了带权并查集的log做法,就是我们不使用路径压缩,然后我们只是按秩合并。
但是实际上带权并查集也是可以路径压缩的,关键在于我们的路径压缩怎么维护front呢?
给出以下find的路径压缩的带权并查集代码:
int find(int u){ if(fa[u]==u)return u; int ff=find(fa[u]); front[u]+=front[fa[u]]; return fa[u]=ff; }
我们先调用一个find然后存在ff里面,先递归它,然后我们在递归结束已经得到fa的front数组的时候再加一下,最后连接就可以了。
最后看一下本题的完整解法
代码:
#include<bits/stdc++.h> #define LL long long using namespace std; const int N=3e4+5; int fa[N],siz[N],front[N]; int find(int u){ if(fa[u]==u)return u; int ff=find(fa[u]); front[u]+=front[fa[u]]; return fa[u]=ff; } int main(){ int T;scanf("%d",&T); for(int i=1;i<N;i++)fa[i]=i,siz[i]=1; while(T--){ char opt[5];int x,y,fx,fy; scanf("%s%d%d",opt,&x,&y); fx=find(x);fy=find(y); if(opt[0]=='M'){ front[fx]+=siz[fy]; siz[fy]+=siz[fx]; fa[fx]=fy; }else{ if(fx!=fy)puts("-1"); else printf("%d\n",abs(front[x]-front[y])-1); } } return 0; }
洛谷P1525 关押罪犯
https://www.luogu.com.cn/problem/P1525
这题实际上我们可以用带权并查集维护
比较显然的思路是我们能加就加,然后知道加出问题之后就停止,我们在这个加的过程中用带权并查集维护01关系,也就是一个二分图就可以。
但是实际上这题我们还可以二分+二分图判断,显然更好理解,然后就不用并查集了
代码:
#include<bits/stdc++.h> #define LL long long #define pb push_back using namespace std; const int N=1e5+5; struct Graph{int u,v,w;}g[N]; int n,m,col[N];bool fl; vector<int>adj[N]; bool cmp(Graph a,Graph b){return a.w>b.w;} void dfs(int u,int c){ col[u]=c; for(auto v:adj[u]){ if(~col[v]&&col[v]!=(c^1)){fl=false;return;} if(col[v]==-1)dfs(v,c^1); } } bool check(int t){ for(int i=1;i<=n;i++)adj[i].clear(); for(int i=1,u,v;i<=t;i++) u=g[i].u,v=g[i].v,adj[u].pb(v),adj[v].pb(u); memset(col,-1,sizeof(col)); fl=true; for(int i=1;i<=n;i++)if(col[i]==-1)dfs(i,0); return fl; } int main(){ scanf("%d%d",&n,&m); for(int i=1;i<=m;i++)scanf("%d%d%d",&g[i].u,&g[i].v,&g[i].w); sort(g+1,g+m+1,cmp); int l=1; for(int k=20,t;~k;k--){ t=l+(1<<k); if(t<=m&&check(t))l=t; } printf("%d\n",g[l+1].w); return 0; }