题意

给出一颗树,找两条没有共点的路径,记两条路径经过的点的个数为 a b a、b ab,求 p a i r ( a , b ) pair(a,b) pair(a,b) 的种类数

题解

考虑枚举结点,作为经过他路径的LCA时,还能找到合法的最长路径

首先,我们将直径的端点作为树根,变成有根树

考虑树根,经过他的最长路径就是直径 a a a,所以,选了直径后,再选一条最长的路径(去掉直径上的点后的图的新直径) b b b,就得到一个范围 a b a\cdot b ab

考虑非直径上的结点,他们作为LCA的路径的长度肯定小于等于 b, 再找一条合法的最长路径一定是直径 a a a ,所以,这些点产生的范围都被包含在树根的范围内了,不需要考虑了

现在只剩下直径上的点,对于点 x x x, 经过他的最长路径(前提是他要作为LCA),一部分是向下的直径部分,另一部分是除了直径部分的最长路,也就是说次深的路径。
还需要找另一条的最长路径
首先,这条路径的端点一定是树根!!!
这里可以简单证明,不再赘述
也就是说,要找到树根向下的不经过当前点的最长路,这个求法比较简单,就是求出直径上的点向直径以外延伸的最长路,加上到树根的距离

至此,我们就得出来很多的范围,下面是求答案
对于每一个 p a i r ( a , b ) pair(a,b) pair(a,b) 我们只需要维护 a a a 的配对最大值, b b b 的配对最大值
最后,对于 a 的升序,b 应该为降序
遍历 a ,累加上 b 就是最后的答案

代码

#include<bits/stdc++.h>
#define N 100010
#define INF 0x3f3f3f3f
#define eps 1e-6
#define pi 3.141592653589793
#define mod 998244353
#define P 1000000007
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<<x<<endl
#define mem(x,y) memset(x,0,sizeof(int)*(y+3))
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;

vector<int> G[N];
int q[N],d[N],w[N],flag[N],c[N],cnt,tar,st,mx,n;

void findpath(int x,int fa,int dep){
    if (dep==mx){q[++cnt]=x; return; }
    for(auto i:G[x]) if (i!=fa){
        findpath(i,x,dep+1);
        if (cnt){
            q[++cnt]=x;
            return;
        }
    }
}

void dfs(int x,int fa,int dep){
    c[x]=1;
    for(auto i:G[x]) if (i!=fa&&!flag[i]){
        if (dep+1>mx) { mx=dep+1;tar=i;}
        dfs(i,x,dep+1);
    }
}

int main(int argc, char const *argv[])
{
    int T; sc(T);
    while(T--){
        sc(n);
        mem(flag,n);  mem(w,n);
        for(int i=1;i<=n;i++) G[i].cl();

        for(int i=1,x,y;i<n;i++){
            scc(x,y);
            G[x].pb(y); G[y].pb(x);
        }

        mx=0; dfs(1,-1,0);
        mx=0; dfs(tar,-1,0);

        int ans1=mx+1,ans2=0;

        cnt=0; findpath(tar,-1,0);
        
        for(int i=1;i<=cnt;i++) flag[q[i]]=1;
        for(int i=1;i<=cnt;i++) {
            mx=0; dfs(q[i],-1,0);
            d[i]=mx+1; 
        }
        for(int i=cnt;i>0;i--) w[i]=max(w[i+1],d[i]+cnt-i);

        mem(c,n); mem(q,n);
        for(int i=1;i<=n;i++) if (!flag[i]&&!c[i]){
            mx=0; dfs(i,-1,0);
            mx=0; dfs(tar,-1,0);
            ans2=max(ans2,mx+1);
        }
        q[ans2]=ans1; q[ans1]=ans2;

        mx=0;
        for(int i=1;i<cnt;i++){
            mx=max(mx,d[i]+i-1);
            q[mx]=max(q[mx],w[i+1]);
            q[w[i+1]]=max(q[w[i+1]],mx);
        }

        for(int i=n;i>0;i--){
            if (!q[i]) q[i]=q[i+1];
            q[i]=max(q[i],q[i+1]);
        }
        LL ans=0;
        for(int i=1;i<=n;i++) ans+=q[i];
        printf("%lld\n",ans);
    }
    return 0;
}