题意:

给出个点的树,对于所有点对求它们在两棵树中公共的公共祖先数量之和。

题解:

不考虑求点对的贡献,考虑求祖先的贡献。

枚举一个祖先,假设两棵树上点的公共子孙个数为,那么这个点的贡献就是

难点就在如何求两棵树上点的公共子孙个数。

解法一:

求出两棵树的,发现就是求满足的个数。
这里就是就是

这不就是一个的矩阵,进行次单点加,次矩阵求和。我们可以用树套树解决。

时间复杂度:

解法二:

看成两维。

第一维我们用莫队,然后分块维护第二维。

时间复杂度:

#include<bits/stdc++.h>
using namespace std;
#define next Next
#define gc getchar
#define int long long
const int N=1e5+5;
int n,m,len,num,ans,t,rt1,rt2,id[N],size[N],ru1[N],chu1[N],ru2[N],chu2[N];
int bel[N],L[N],R[N],b[N],a[N],sum[N];
struct node{
    int l,r,L,R;
}q[N];
vector<int>g[N],G[N];
//char buf[1<<21],*p1=buf,*p2=buf;
//inline int gc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
inline int read()
{
    int ret=0,f=0;char c=gc();
    while(!isdigit(c)){if(c=='-')f=1;c=gc();}
    while(isdigit(c)){ret=ret*10+c-48;c=gc();}
    if(f)return -ret;return ret;
}
void dfs(int u,int fa)
{
    size[u]=1;
    ru1[u]=++t;
    id[t]=u;
    for(int i=0;i<g[u].size();i++)
    {
        int v=g[u][i];
        if(v==fa)continue;
        dfs(v,u);
        size[u]+=size[v];
    }
    chu1[u]=t;
}
void dfs2(int u,int fa)
{
    size[u]=1;
    ru2[u]=++t;
    for(int i=0;i<G[u].size();i++)
    {
        int v=G[u][i];
        if(v==fa)continue;
        dfs2(v,u);
        size[u]+=size[v];
    }
    chu2[u]=t;
}
bool cmp(node a,node b)
{
    return (a.l/len)^(b.l/len)?a.l<b.l:((a.l/len)&1)?a.r<b.r:a.r>b.r;
}
void add(int x)
{
    int u=id[x];
    a[ru2[u]]++;
    sum[bel[ru2[u]]]++;
}
void del(int x)
{
    int u=id[x];
    a[ru2[u]]--;
    sum[bel[ru2[u]]]--;
}
int solve(int l,int r)
{
    int res=0;
    if(bel[r]-bel[l]<=2)
    {
        for(int i=l;i<=r;i++)res+=a[i];
        return res*(res-1)/2;
    }
    for(int i=l;i<=R[bel[l]];i++)res+=a[i];
    for(int i=L[bel[r]];i<=r;i++)res+=a[i];
    for(int i=bel[l]+1;i<bel[r];i++)res+=sum[i];
    return res*(res-1)/2;
}
signed main()
{
    n=read();
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        b[y]=1;
        g[x].push_back(y);
    }
    for(int i=1;i<=n;i++)
        if(!b[i])rt1=i;
        else b[i]=0;
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        b[y]=1;
        G[x].push_back(y);
    }
    for(int i=1;i<=n;i++)
        if(!b[i])rt2=i;
        else b[i]=0;
    dfs(rt1,0);
    t=0;
    dfs2(rt2,0);
    len=sqrt(n);
    num=(n-1)/len+1;
    for(int i=1;i<=n;i++)L[i]=n+1;
    for(int i=1;i<=n;i++)
    {
        bel[i]=(i-1)/len+1;
        L[bel[i]]=min(L[bel[i]],i);
        R[bel[i]]=max(R[bel[i]],i);
    }
    for(int i=1;i<=n;i++)
    {
        q[i].l=ru1[i]+1;
        q[i].r=chu1[i];
        q[i].L=ru2[i]+1;
        q[i].R=chu2[i];
    }
    sort(q+1,q+n+1,cmp);
    int l=1,r=0;
    for(int i=1;i<=n;i++)
    {
        if(q[i].L>q[i].R)continue;
        while(l>q[i].l)add(--l);
        while(r<q[i].r)add(++r);
        while(l<q[i].l)del(l++);
        while(r>q[i].r)del(r--);
        ans+=solve(q[i].L,q[i].R);
    }
    cout<<ans;
    return 0;
}

解法三:

我们发现第一棵树中点的子树中的点就是满足的点。

然后那些点中满足的就是要求的。

相当于我们只要用权值线段树维护即可,每次查询线段树中之间的数量。

维护子树信息,我们想到可以用线段树合并来解决。

时间复杂度:

#include<bits/stdc++.h>
using namespace std;
#define next Next
#define gc getchar
#define int long long
const int N=1e5+5;
int n,m,ans,t,cnt,root1,root2,b[N],id[N],size[N],ru1[N],chu1[N],ru2[N],chu2[N],rt[N];
vector<int>g[N],G[N];
struct node{
    int l,r,sum;
}tr[N*32];
//char buf[1<<21],*p1=buf,*p2=buf;
//inline int gc(){return p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++;}
inline int read()
{
    int ret=0,f=0;char c=gc();
    while(!isdigit(c)){if(c=='-')f=1;c=gc();}
    while(isdigit(c)){ret=ret*10+c-48;c=gc();}
    if(f)return -ret;return ret;
}
void dfs(int u,int fa)
{
    size[u]=1;
    ru1[u]=++t;
    id[t]=u;
    for(int i=0;i<g[u].size();i++)
    {
        int v=g[u][i];
        if(v==fa)continue;
        dfs(v,u);
        size[u]+=size[v];
    }
    chu1[u]=t;
}
void dfs2(int u,int fa)
{
    size[u]=1;
    ru2[u]=++t;
    for(int i=0;i<G[u].size();i++)
    {
        int v=G[u][i];
        if(v==fa)continue;
        dfs2(v,u);
        size[u]+=size[v];
    }
    chu2[u]=t;
}
void pushup(int u)
{
    tr[u].sum=tr[tr[u].l].sum+tr[tr[u].r].sum;
}
int merge(int a,int b,int l,int r)
{
    if(!a)return b;
    if(!b)return a;
    if(l==r)
    {
        tr[a].sum+=tr[b].sum;
        return a;
    }
    int mid=(l+r)/2;
    tr[a].l=merge(tr[a].l,tr[b].l,l,mid);
    tr[a].r=merge(tr[a].r,tr[b].r,mid+1,r);
    pushup(a);
    return a;
}
void change(int &u,int l,int r,int x)
{
    if(!u)u=++cnt;
    if(l==r)
    {
        tr[u].sum++;
        return;
    }
    int mid=(l+r)/2;
    if(x<=mid)change(tr[u].l,l,mid,x);
    else change(tr[u].r,mid+1,r,x);
    pushup(u);
}
int find(int u,int l,int r,int L,int R)
{
    if(!u)return 0;
    if(l==L&&r==R)return tr[u].sum;
    int mid=(l+r)/2;
    if(R<=mid)return find(tr[u].l,l,mid,L,R);
    else if(L>mid)return find(tr[u].r,mid+1,r,L,R);
    else return find(tr[u].l,l,mid,L,mid)+find(tr[u].r,mid+1,r,mid+1,R);
}
void solve(int u,int fa)
{
    for(int i=0;i<g[u].size();i++)
    {
        int v=g[u][i];
        if(v==fa)continue;
        solve(v,u);
        rt[u]=merge(rt[u],rt[v],1,n);
    }
    change(rt[u],1,n,ru2[u]);
    int res=find(rt[u],1,n,ru2[u],chu2[u])-1;
    res=res*(res-1)/2;
    ans+=res;
}
signed main()
{
    n=read();
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        b[y]=1;
        g[x].push_back(y);
    }
    for(int i=1;i<=n;i++)
        if(!b[i])root1=i;
        else b[i]=0;
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        b[y]=1;
        G[x].push_back(y);
    }
    for(int i=1;i<=n;i++)
        if(!b[i])root2=i;
        else b[i]=0;
    dfs(root1,0);
    t=0;
    dfs2(root2,0);
    solve(root1,0);
    cout<<ans;
    return 0;
}