题目大意:
给你两棵树,在这两棵树上分别找一个点,将其连接,使得\sum_{i=1}^{n-1}\sum_{j=i+1}^{n}dis(i,j) 最小,其中dis(i,j)表示从节点 i 到节点 j 的边数。
Part1
首先我们需要判断我们找的这两个点应该是哪个点,对于两棵树,他们的 dis 和是固定的,因此我们需要讨论将两个点连接起来所增加的花费。
假设需要连接的两棵树A,B,两棵树上进行连接的点为 u,v ,
点 u,v 到其所在子树其他点的距离之和为Dis_u,Dis_v ,A,B 上点的个数为 P_A,P_B,
那么将其连接后增加的 dis 值为:
Dis_u*P_B+Dis_v*P_A+P_A*P_B
很容易理解:
对于树 A 上的任意一个点 w ,我们需要将其和 B 上的所有点进行一次连接,等同于需要将dis(w,u) 重复计算 P_B 次,其他点同理,因此 A 树上增加的 dis 值为 Dis_u*P_B ,B树同理。
而对于刚建立的通道 dis(u,v)=1 被使用了 P_A*P_B 因此总的增加量即为上式。
P_A,P_B为定值,所以我们只需要最小化 Dis_u,Dis_v 即可。
Part2
现在的问题已经简化成了如何求一棵树上的 Dis 的最小值。
首先我们需要一遍dfs将树的根节点的 Dis 值找出来,找出来之后,我们就使用换根dp计算出树上所有节点的Dis值。
然后找出 Dis 最小的点,进行连接,再次进行上述操作即可。
最后将所有节点的 Dis 全部求和,由于这个值是求的双向的,因此需要除以2。
代码如下(代码有些冗长):
#include<bits/stdc++.h> using namespace std; #define ll long long #define pii pair<int,int> const ll INF=0x3f3f3f3f3f3f3f3f; const int maxn=1000000+10; const ll mod=1e9+7; vector<ll>vec[maxn]; ll n; ll dp[maxn]; ll num[maxn]; ll vis[maxn]; ll vis2[maxn]; ll value[maxn]; ll flag,point; void init(){//清空数组 memset(dp,0,sizeof(dp)); memset(num,0,sizeof(num)); memset(vis,0,sizeof(vis)); memset(vis2,0,sizeof(vis2)); memset(value,0,sizeof(value)); } ll dfs(ll x){//dfs找所有的值,num表示以这个点为根节点下面有几个点(包括这个点,dp用),value是这个点到所有子树的距离之和 vis[x]=flag; ll sum=0; for(ll i=0;i<vec[x].size();i++){ ll y=vec[x][i]; if(!vis[y]){ int q=dfs(y); sum+=q; value[x]+=value[y]+q; } } num[x]=sum+1; return sum+1; } void Dp(ll x){//dp数组表示当前这个点到其他点的距离之和 vis2[x]=flag; for(ll i=0;i<vec[x].size();i++){ ll y=vec[x][i]; if(!vis2[y]){ dp[y]=dp[x]-num[y]+point-num[y];//换根dp方程 Dp(y); } } } void solve(){ for(int i=1;i<=n;i++) vec[i].clear(); init(); ll u,v; for(ll i=0;i<n-2;i++){ scanf("%lld %lld",&u,&v); vec[u].push_back(v); vec[v].push_back(u); } //使用flag的值对第一颗子树和第二棵子树进行区分,下同 flag=1; dfs(1); dp[1]=value[1]; flag++; for(ll i=1;i<=n;i++) if(!vis[i]){ dfs(i); dp[i]=value[i]; break; } //point是当前两个子树之一的点的个数 point=0; for(int i=1;i<=n;i++) if(vis[i]==1) point++; flag=1; Dp(1); flag++; point=n-point; for(ll i=1;i<=n;i++) if(!vis2[i]){ Dp(i); break; } //找到两个树上到其他点距离最小的点 ll v1=INF,v2=INF,p1,p2; for(ll i=1;i<=n;i++){ if(vis2[i]==1 && dp[i]<=v1){ p1=i; v1=dp[i]; } if(vis2[i]==2 && dp[i]<=v2){ p2=i; v2=dp[i]; } } vec[p1].push_back(p2); vec[p2].push_back(p1); //重新进行dfs 和 dp 计算出这棵大树的所有dp值 init(); flag=1; dfs(1); dp[1]=value[1]; point=n; Dp(1); //a->b b->a计算两次,所以 /2 ll res=0; for(ll i=1;i<=n;i++) res+=dp[i]; printf("%lld\n",res/2ll); } int main() { while(~scanf("%lld",&n)) solve(); return 0; }