C.
题意
给两棵树,结点数分别为和,对于所有的点对(可以在一棵树内部,也可以分别在两棵树上),求出,其中函数表示给这两棵树任意加一条边联通后,这两点的最远距离。也就是求
分析
很明显,对于所有的点对,可以分为两种情况
- 和在同一棵树内
- 和分别在两棵树上
第一种情况(s和t在同一颗树内)
我们需要求的相当于是一棵树内任意两点的距离和。要处理这个东西,可以这么考虑,对于这棵树上的一条边,我们可以统计树上所有路径中,经过这条边的次数,那么这条边对距离和的贡献就等于,那么怎么求呢?可以这么考虑,所有的经过这条边的路径数量,相当于在左端点及其左边任选一个点,右端点及其右边任选一个点,这样选出来的路径肯定都有经过这条边。 所以我们只需要开一个数组记录的子树的大小,那么对于结点和它父亲的这条边,对应的总路径数,这里代表总的结点数,减去的子树大小,就等于在上面的结点个数。本题中所有的边权为1,分别为和。
第二种情况(s和t在不同的树上)
那么显然有一种贪心的想法,我们可以在第一棵树内找到距离最远的点,在第二棵树内找到距离最远的点,那么加的边就可以是,这样对应的是最大的。 所以接下来我们要做的就是快速地找出每棵树内,每个结点对应的最远距离,这里有两种做法:
针对情况二的做法1
赛时我用的是较为麻烦的做法,开两个数组和分别表示从结点开始,第一步往子节点和往父节点走,能走的最远距离。所以一个点能走到的最远距离,要么是往父节点走,要么是往子节点走,也就是,问题转化为如何求和。 那么数组是比较好得到的,直接简单的一遍,设当前结点为,子节点为,。 难点在于,不过也还好,想一下递推关系就可以知道,第一步往父节点走,最远距离等于父节点往其他子节点走的最远距离加上父节点到该子节点的距离,即 其中表示结点的兄弟结点,表示结点的父节点。所以只需要再进行一次即可求出。
针对情况二的做法2
赛后看到题解,发现上面的做法太麻烦了,直接利用一个性质:树上的一个点距离最远的点的距离,等于到树上直径的两个端点中较远的点的距离。直观上感觉确实这样,不过具体证明我也不会。 那么利用这个性质的话,就可以用求树的直径的方法,两遍,第一遍任选一个点出发,记录能到达的最远的点为,即为直径的一端,从点出发,再次,顺便维护树上所有点到的距离,能到达的最远的点就是直径的另一端,即为直径,再从出发一遍,维护树上所有点到点的距离,那么对于所有的点,能到达的最远距离就是 。 虽然这种方法需要的次数比较多,但是思路上还是比较简单的,不用考虑那么多转移,我个人感觉这种方法比前面讲的做法1要简单。
那么如果已经求出来了每个点在树内能到达的最远距离数组,下面考虑具体如何计算出最终答案。我是对着第一组样例来手动列一下算式总结。如对于两棵树都只有2个点,一条边,编号分别为1,2和3,4的情况,对于和分别在两棵树上时,有 可以看出,对于手动连的中间那条边,也就是式子中间的1,,总共有条。 对于左边的,有个点的与其相加,对于同理,所以提取公因式,等于。 同理,对于右边的和,求和为,不妨记第一棵个结点的树中,各结点的最远距离和为,第二棵树对应的节点距离和 为。 那么就等于。 再把这个答案加上开始求的在一棵树内的距离和,即为最终答案(记得开long long,我习惯宏定义int long long)。 下面分别为两种方法求第二种情况的代码: 第一种
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=2e5+10;
vector<int>G[maxn];
int siz[maxn],ans,dp1[maxn],dp2[maxn],fa[maxn];
void dfs(int u,int f,int tot){
siz[u]=1;
for(auto v:G[u]){
if(v==f) continue;
fa[v]=u;
dfs(v,u,tot);
dp1[u]=max(dp1[u],dp1[v]+1);
siz[u]+=siz[v];
}
ans+=siz[u]*(tot-siz[u]);
}
void dfs2(int u,int f){
int t=0;
dp2[u]=dp2[f];
for(auto v:G[f]){
if(v==fa[f])continue;
if(v==u)t=1;
else dp2[u]=max(dp2[u],dp1[v]+1);
}
dp2[u]+=t;
for(auto v:G[u]){
if(v==f) continue;
dfs2(v,u);
}
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
int n,m;cin>>n>>m;
for(int i=2;i<n+m;i++){
int u,v;cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,1,n);
dfs2(1,0);
dfs(n+1,n+1,m);
dfs2(n+1,0);
int sum1=0,sum2=0;
for(int i=1;i<=n;i++)
sum1+=max(dp1[i],dp2[i]);
for(int i=n+1;i<=n+m;i++)
sum2+=max(dp1[i],dp2[i]);
ans+=n*m+n*sum2+m*sum1;
cout<<ans<<endl;
return 0;
}
第二种
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=2e5+10;
vector<int>G[maxn];
int dis1[maxn],dis2[maxn],s,t,siz[maxn];
ll ans;
void dfs(int u,int f,int tot){
siz[u]=1;
for(auto v:G[u]){
if(v==f) continue;
dfs(v,u,tot);
siz[u]+=siz[v];
}
ans+=siz[u]*(tot-siz[u]);
}
void dfs1(int u,int f){
for(auto v:G[u]){
if(v==f) continue;
dis1[v]=dis1[u]+1;
if(dis1[v]>dis1[s]) s=v;
dfs1(v,u);
}
}
void dfs2(int u,int f){
for(auto v:G[u]){
if(v==f) continue;
dis2[v]=dis2[u]+1;
if(dis2[v]>dis2[t]) t=v;
dfs2(v,u);
}
}
signed main(){
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
int n,m;cin>>n>>m;
for(int i=1;i<=n+m-2;i++){
int u,v;cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1,1);
dfs2(s,s);//得到所有点到s的距离
dis1[t]=0;
dfs1(t,t);
s=t=0;
dfs1(n+1,n+1);
dfs2(s,s);
dis1[t]=0;
dfs1(t,t);
dfs(1,1,n);
dfs(n+1,n+1,m);
int sum1=0,sum2=0;
for(int i=1;i<=n;i++)
sum1+=max(dis1[i],dis2[i]);
for(int i=n+1;i<=n+m;i++)
sum2+=max(dis1[i],dis2[i]);
ans+=n*m+n*sum2+m*sum1;
cout<<ans<<endl;
return 0;
}