题意:
给出个点的树,对于所有点对求它们在两棵树中公共的公共祖先数量之和。
题解:
不考虑求点对的贡献,考虑求祖先的贡献。
枚举一个祖先,假设两棵树上点的公共子孙个数为,那么这个点的贡献就是
难点就在如何求两棵树上点的公共子孙个数。
解法一:
求出两棵树的,发现就是求满足和的的个数。
这里就是的,就是
这不就是一个的矩阵,进行次单点加,次矩阵求和。我们可以用树套树解决。
时间复杂度:
解法二:
把和看成两维。
第一维我们用莫队,然后分块维护第二维。
时间复杂度:
#include<bits/stdc++.h> using namespace std; #define next Next #define gc getchar #define int long long const int N=1e5+5; int n,m,len,num,ans,t,rt1,rt2,id[N],size[N],ru1[N],chu1[N],ru2[N],chu2[N]; int bel[N],L[N],R[N],b[N],a[N],sum[N]; struct node{ int l,r,L,R; }q[N]; vector<int>g[N],G[N]; //char buf[1<<21],*p1=buf,*p2=buf; //inline int gc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;} inline int read() { int ret=0,f=0;char c=gc(); while(!isdigit(c)){if(c=='-')f=1;c=gc();} while(isdigit(c)){ret=ret*10+c-48;c=gc();} if(f)return -ret;return ret; } void dfs(int u,int fa) { size[u]=1; ru1[u]=++t; id[t]=u; for(int i=0;i<g[u].size();i++) { int v=g[u][i]; if(v==fa)continue; dfs(v,u); size[u]+=size[v]; } chu1[u]=t; } void dfs2(int u,int fa) { size[u]=1; ru2[u]=++t; for(int i=0;i<G[u].size();i++) { int v=G[u][i]; if(v==fa)continue; dfs2(v,u); size[u]+=size[v]; } chu2[u]=t; } bool cmp(node a,node b) { return (a.l/len)^(b.l/len)?a.l<b.l:((a.l/len)&1)?a.r<b.r:a.r>b.r; } void add(int x) { int u=id[x]; a[ru2[u]]++; sum[bel[ru2[u]]]++; } void del(int x) { int u=id[x]; a[ru2[u]]--; sum[bel[ru2[u]]]--; } int solve(int l,int r) { int res=0; if(bel[r]-bel[l]<=2) { for(int i=l;i<=r;i++)res+=a[i]; return res*(res-1)/2; } for(int i=l;i<=R[bel[l]];i++)res+=a[i]; for(int i=L[bel[r]];i<=r;i++)res+=a[i]; for(int i=bel[l]+1;i<bel[r];i++)res+=sum[i]; return res*(res-1)/2; } signed main() { n=read(); for(int i=1;i<n;i++) { int x=read(),y=read(); b[y]=1; g[x].push_back(y); } for(int i=1;i<=n;i++) if(!b[i])rt1=i; else b[i]=0; for(int i=1;i<n;i++) { int x=read(),y=read(); b[y]=1; G[x].push_back(y); } for(int i=1;i<=n;i++) if(!b[i])rt2=i; else b[i]=0; dfs(rt1,0); t=0; dfs2(rt2,0); len=sqrt(n); num=(n-1)/len+1; for(int i=1;i<=n;i++)L[i]=n+1; for(int i=1;i<=n;i++) { bel[i]=(i-1)/len+1; L[bel[i]]=min(L[bel[i]],i); R[bel[i]]=max(R[bel[i]],i); } for(int i=1;i<=n;i++) { q[i].l=ru1[i]+1; q[i].r=chu1[i]; q[i].L=ru2[i]+1; q[i].R=chu2[i]; } sort(q+1,q+n+1,cmp); int l=1,r=0; for(int i=1;i<=n;i++) { if(q[i].L>q[i].R)continue; while(l>q[i].l)add(--l); while(r<q[i].r)add(++r); while(l<q[i].l)del(l++); while(r>q[i].r)del(r--); ans+=solve(q[i].L,q[i].R); } cout<<ans; return 0; }
解法三:
我们发现第一棵树中点的子树中的点就是满足的点。
然后那些点中满足的就是要求的。
相当于我们只要用权值线段树维护即可,每次查询线段树中之间的数量。
维护子树信息,我们想到可以用线段树合并来解决。
时间复杂度:
#include<bits/stdc++.h> using namespace std; #define next Next #define gc getchar #define int long long const int N=1e5+5; int n,m,ans,t,cnt,root1,root2,b[N],id[N],size[N],ru1[N],chu1[N],ru2[N],chu2[N],rt[N]; vector<int>g[N],G[N]; struct node{ int l,r,sum; }tr[N*32]; //char buf[1<<21],*p1=buf,*p2=buf; //inline int gc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;} inline int read() { int ret=0,f=0;char c=gc(); while(!isdigit(c)){if(c=='-')f=1;c=gc();} while(isdigit(c)){ret=ret*10+c-48;c=gc();} if(f)return -ret;return ret; } void dfs(int u,int fa) { size[u]=1; ru1[u]=++t; id[t]=u; for(int i=0;i<g[u].size();i++) { int v=g[u][i]; if(v==fa)continue; dfs(v,u); size[u]+=size[v]; } chu1[u]=t; } void dfs2(int u,int fa) { size[u]=1; ru2[u]=++t; for(int i=0;i<G[u].size();i++) { int v=G[u][i]; if(v==fa)continue; dfs2(v,u); size[u]+=size[v]; } chu2[u]=t; } void pushup(int u) { tr[u].sum=tr[tr[u].l].sum+tr[tr[u].r].sum; } int merge(int a,int b,int l,int r) { if(!a)return b; if(!b)return a; if(l==r) { tr[a].sum+=tr[b].sum; return a; } int mid=(l+r)/2; tr[a].l=merge(tr[a].l,tr[b].l,l,mid); tr[a].r=merge(tr[a].r,tr[b].r,mid+1,r); pushup(a); return a; } void change(int &u,int l,int r,int x) { if(!u)u=++cnt; if(l==r) { tr[u].sum++; return; } int mid=(l+r)/2; if(x<=mid)change(tr[u].l,l,mid,x); else change(tr[u].r,mid+1,r,x); pushup(u); } int find(int u,int l,int r,int L,int R) { if(!u)return 0; if(l==L&&r==R)return tr[u].sum; int mid=(l+r)/2; if(R<=mid)return find(tr[u].l,l,mid,L,R); else if(L>mid)return find(tr[u].r,mid+1,r,L,R); else return find(tr[u].l,l,mid,L,mid)+find(tr[u].r,mid+1,r,mid+1,R); } void solve(int u,int fa) { for(int i=0;i<g[u].size();i++) { int v=g[u][i]; if(v==fa)continue; solve(v,u); rt[u]=merge(rt[u],rt[v],1,n); } change(rt[u],1,n,ru2[u]); int res=find(rt[u],1,n,ru2[u],chu2[u])-1; res=res*(res-1)/2; ans+=res; } signed main() { n=read(); for(int i=1;i<n;i++) { int x=read(),y=read(); b[y]=1; g[x].push_back(y); } for(int i=1;i<=n;i++) if(!b[i])root1=i; else b[i]=0; for(int i=1;i<n;i++) { int x=read(),y=read(); b[y]=1; G[x].push_back(y); } for(int i=1;i<=n;i++) if(!b[i])root2=i; else b[i]=0; dfs(root1,0); t=0; dfs2(root2,0); solve(root1,0); cout<<ans; return 0; }