题意:
给出个点的树,对于所有点对求它们在两棵树中公共的公共祖先数量之和。
题解:
不考虑求点对的贡献,考虑求祖先的贡献。
枚举一个祖先,假设两棵树上点
的公共子孙个数为
,那么这个点的贡献就是
难点就在如何求两棵树上点的公共子孙个数。
解法一:
求出两棵树的,发现就是求满足
和
的
的个数。
这里就是
的
,
就是
这不就是一个的矩阵,进行
次单点加,
次矩阵求和。我们可以用树套树解决。
时间复杂度:
解法二:
把和
看成两维。
第一维我们用莫队,然后分块维护第二维。
时间复杂度:
#include<bits/stdc++.h>
using namespace std;
#define next Next
#define gc getchar
#define int long long
const int N=1e5+5;
int n,m,len,num,ans,t,rt1,rt2,id[N],size[N],ru1[N],chu1[N],ru2[N],chu2[N];
int bel[N],L[N],R[N],b[N],a[N],sum[N];
struct node{
int l,r,L,R;
}q[N];
vector<int>g[N],G[N];
//char buf[1<<21],*p1=buf,*p2=buf;
//inline int gc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
inline int read()
{
int ret=0,f=0;char c=gc();
while(!isdigit(c)){if(c=='-')f=1;c=gc();}
while(isdigit(c)){ret=ret*10+c-48;c=gc();}
if(f)return -ret;return ret;
}
void dfs(int u,int fa)
{
size[u]=1;
ru1[u]=++t;
id[t]=u;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa)continue;
dfs(v,u);
size[u]+=size[v];
}
chu1[u]=t;
}
void dfs2(int u,int fa)
{
size[u]=1;
ru2[u]=++t;
for(int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if(v==fa)continue;
dfs2(v,u);
size[u]+=size[v];
}
chu2[u]=t;
}
bool cmp(node a,node b)
{
return (a.l/len)^(b.l/len)?a.l<b.l:((a.l/len)&1)?a.r<b.r:a.r>b.r;
}
void add(int x)
{
int u=id[x];
a[ru2[u]]++;
sum[bel[ru2[u]]]++;
}
void del(int x)
{
int u=id[x];
a[ru2[u]]--;
sum[bel[ru2[u]]]--;
}
int solve(int l,int r)
{
int res=0;
if(bel[r]-bel[l]<=2)
{
for(int i=l;i<=r;i++)res+=a[i];
return res*(res-1)/2;
}
for(int i=l;i<=R[bel[l]];i++)res+=a[i];
for(int i=L[bel[r]];i<=r;i++)res+=a[i];
for(int i=bel[l]+1;i<bel[r];i++)res+=sum[i];
return res*(res-1)/2;
}
signed main()
{
n=read();
for(int i=1;i<n;i++)
{
int x=read(),y=read();
b[y]=1;
g[x].push_back(y);
}
for(int i=1;i<=n;i++)
if(!b[i])rt1=i;
else b[i]=0;
for(int i=1;i<n;i++)
{
int x=read(),y=read();
b[y]=1;
G[x].push_back(y);
}
for(int i=1;i<=n;i++)
if(!b[i])rt2=i;
else b[i]=0;
dfs(rt1,0);
t=0;
dfs2(rt2,0);
len=sqrt(n);
num=(n-1)/len+1;
for(int i=1;i<=n;i++)L[i]=n+1;
for(int i=1;i<=n;i++)
{
bel[i]=(i-1)/len+1;
L[bel[i]]=min(L[bel[i]],i);
R[bel[i]]=max(R[bel[i]],i);
}
for(int i=1;i<=n;i++)
{
q[i].l=ru1[i]+1;
q[i].r=chu1[i];
q[i].L=ru2[i]+1;
q[i].R=chu2[i];
}
sort(q+1,q+n+1,cmp);
int l=1,r=0;
for(int i=1;i<=n;i++)
{
if(q[i].L>q[i].R)continue;
while(l>q[i].l)add(--l);
while(r<q[i].r)add(++r);
while(l<q[i].l)del(l++);
while(r>q[i].r)del(r--);
ans+=solve(q[i].L,q[i].R);
}
cout<<ans;
return 0;
} 解法三:
我们发现第一棵树中点的子树中的点就是满足
的点。
然后那些点中满足的就是要求的。
相当于我们只要用权值线段树维护即可,每次查询线段树中
之间的数量。
维护子树信息,我们想到可以用线段树合并来解决。
时间复杂度:
#include<bits/stdc++.h>
using namespace std;
#define next Next
#define gc getchar
#define int long long
const int N=1e5+5;
int n,m,ans,t,cnt,root1,root2,b[N],id[N],size[N],ru1[N],chu1[N],ru2[N],chu2[N],rt[N];
vector<int>g[N],G[N];
struct node{
int l,r,sum;
}tr[N*32];
//char buf[1<<21],*p1=buf,*p2=buf;
//inline int gc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
inline int read()
{
int ret=0,f=0;char c=gc();
while(!isdigit(c)){if(c=='-')f=1;c=gc();}
while(isdigit(c)){ret=ret*10+c-48;c=gc();}
if(f)return -ret;return ret;
}
void dfs(int u,int fa)
{
size[u]=1;
ru1[u]=++t;
id[t]=u;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa)continue;
dfs(v,u);
size[u]+=size[v];
}
chu1[u]=t;
}
void dfs2(int u,int fa)
{
size[u]=1;
ru2[u]=++t;
for(int i=0;i<G[u].size();i++)
{
int v=G[u][i];
if(v==fa)continue;
dfs2(v,u);
size[u]+=size[v];
}
chu2[u]=t;
}
void pushup(int u)
{
tr[u].sum=tr[tr[u].l].sum+tr[tr[u].r].sum;
}
int merge(int a,int b,int l,int r)
{
if(!a)return b;
if(!b)return a;
if(l==r)
{
tr[a].sum+=tr[b].sum;
return a;
}
int mid=(l+r)/2;
tr[a].l=merge(tr[a].l,tr[b].l,l,mid);
tr[a].r=merge(tr[a].r,tr[b].r,mid+1,r);
pushup(a);
return a;
}
void change(int &u,int l,int r,int x)
{
if(!u)u=++cnt;
if(l==r)
{
tr[u].sum++;
return;
}
int mid=(l+r)/2;
if(x<=mid)change(tr[u].l,l,mid,x);
else change(tr[u].r,mid+1,r,x);
pushup(u);
}
int find(int u,int l,int r,int L,int R)
{
if(!u)return 0;
if(l==L&&r==R)return tr[u].sum;
int mid=(l+r)/2;
if(R<=mid)return find(tr[u].l,l,mid,L,R);
else if(L>mid)return find(tr[u].r,mid+1,r,L,R);
else return find(tr[u].l,l,mid,L,mid)+find(tr[u].r,mid+1,r,mid+1,R);
}
void solve(int u,int fa)
{
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa)continue;
solve(v,u);
rt[u]=merge(rt[u],rt[v],1,n);
}
change(rt[u],1,n,ru2[u]);
int res=find(rt[u],1,n,ru2[u],chu2[u])-1;
res=res*(res-1)/2;
ans+=res;
}
signed main()
{
n=read();
for(int i=1;i<n;i++)
{
int x=read(),y=read();
b[y]=1;
g[x].push_back(y);
}
for(int i=1;i<=n;i++)
if(!b[i])root1=i;
else b[i]=0;
for(int i=1;i<n;i++)
{
int x=read(),y=read();
b[y]=1;
G[x].push_back(y);
}
for(int i=1;i<=n;i++)
if(!b[i])root2=i;
else b[i]=0;
dfs(root1,0);
t=0;
dfs2(root2,0);
solve(root1,0);
cout<<ans;
return 0;
} 
京公网安备 11010502036488号