题意
给出一颗树,找两条没有共点的路径,记两条路径经过的点的个数为 a、b,求 pair(a,b) 的种类数
题解
考虑枚举结点,作为经过他路径的LCA时,还能找到合法的最长路径
首先,我们将直径的端点作为树根,变成有根树
考虑树根,经过他的最长路径就是直径 a,所以,选了直径后,再选一条最长的路径(去掉直径上的点后的图的新直径) b,就得到一个范围 a⋅b
考虑非直径上的结点,他们作为LCA的路径的长度肯定小于等于 b, 再找一条合法的最长路径一定是直径 a ,所以,这些点产生的范围都被包含在树根的范围内了,不需要考虑了
现在只剩下直径上的点,对于点 x, 经过他的最长路径(前提是他要作为LCA),一部分是向下的直径部分,另一部分是除了直径部分的最长路,也就是说次深的路径。
还需要找另一条的最长路径
首先,这条路径的端点一定是树根!!!
这里可以简单证明,不再赘述
也就是说,要找到树根向下的不经过当前点的最长路,这个求法比较简单,就是求出直径上的点向直径以外延伸的最长路,加上到树根的距离
至此,我们就得出来很多的范围,下面是求答案
对于每一个 pair(a,b) 我们只需要维护 a 的配对最大值, b 的配对最大值
最后,对于 a 的升序,b 应该为降序
遍历 a ,累加上 b 就是最后的答案
代码
#include<bits/stdc++.h>
#define N 100010
#define INF 0x3f3f3f3f
#define eps 1e-6
#define pi 3.141592653589793
#define mod 998244353
#define P 1000000007
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<<x<<endl
#define mem(x,y) memset(x,0,sizeof(int)*(y+3))
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
vector<int> G[N];
int q[N],d[N],w[N],flag[N],c[N],cnt,tar,st,mx,n;
void findpath(int x,int fa,int dep){
if (dep==mx){q[++cnt]=x; return; }
for(auto i:G[x]) if (i!=fa){
findpath(i,x,dep+1);
if (cnt){
q[++cnt]=x;
return;
}
}
}
void dfs(int x,int fa,int dep){
c[x]=1;
for(auto i:G[x]) if (i!=fa&&!flag[i]){
if (dep+1>mx) { mx=dep+1;tar=i;}
dfs(i,x,dep+1);
}
}
int main(int argc, char const *argv[])
{
int T; sc(T);
while(T--){
sc(n);
mem(flag,n); mem(w,n);
for(int i=1;i<=n;i++) G[i].cl();
for(int i=1,x,y;i<n;i++){
scc(x,y);
G[x].pb(y); G[y].pb(x);
}
mx=0; dfs(1,-1,0);
mx=0; dfs(tar,-1,0);
int ans1=mx+1,ans2=0;
cnt=0; findpath(tar,-1,0);
for(int i=1;i<=cnt;i++) flag[q[i]]=1;
for(int i=1;i<=cnt;i++) {
mx=0; dfs(q[i],-1,0);
d[i]=mx+1;
}
for(int i=cnt;i>0;i--) w[i]=max(w[i+1],d[i]+cnt-i);
mem(c,n); mem(q,n);
for(int i=1;i<=n;i++) if (!flag[i]&&!c[i]){
mx=0; dfs(i,-1,0);
mx=0; dfs(tar,-1,0);
ans2=max(ans2,mx+1);
}
q[ans2]=ans1; q[ans1]=ans2;
mx=0;
for(int i=1;i<cnt;i++){
mx=max(mx,d[i]+i-1);
q[mx]=max(q[mx],w[i+1]);
q[w[i+1]]=max(q[w[i+1]],mx);
}
for(int i=n;i>0;i--){
if (!q[i]) q[i]=q[i+1];
q[i]=max(q[i],q[i+1]);
}
LL ans=0;
for(int i=1;i<=n;i++) ans+=q[i];
printf("%lld\n",ans);
}
return 0;
}