题目大意:

给你两棵树,在这两棵树上分别找一个点,将其连接,使得\sum_{i=1}^{n-1}\sum_{j=i+1}^{n}dis(i,j) 最小,其中dis(i,j)表示从节点 i 到节点 j 的边数。

Part1

首先我们需要判断我们找的这两个点应该是哪个点,对于两棵树,他们的 dis 和是固定的,因此我们需要讨论将两个点连接起来所增加的花费。

假设需要连接的两棵树A,B,两棵树上进行连接的点为 u,v ,

点 u,v 到其所在子树其他点的距离之和为Dis_u,Dis_v ,A,B 上点的个数为 P_A,P_B,

那么将其连接后增加的 dis 值为:

Dis_u*P_B+Dis_v*P_A+P_A*P_B

很容易理解:

对于树 A 上的任意一个点 w ,我们需要将其和 B 上的所有点进行一次连接,等同于需要将dis(w,u) 重复计算 P_B 次,其他点同理,因此 A 树上增加的 dis 值为 Dis_u*P_B ,B树同理。

而对于刚建立的通道 dis(u,v)=1 被使用了 P_A*P_B 因此总的增加量即为上式。

P_A,P_B为定值,所以我们只需要最小化 Dis_u,Dis_v 即可。

Part2

现在的问题已经简化成了如何求一棵树上的 Dis 的最小值。

首先我们需要一遍dfs将树的根节点的 Dis 值找出来,找出来之后,我们就使用换根dp计算出树上所有节点的Dis值。

然后找出 Dis 最小的点,进行连接,再次进行上述操作即可。

最后将所有节点的 Dis 全部求和,由于这个值是求的双向的,因此需要除以2。

代码如下(代码有些冗长):

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int,int>
const ll INF=0x3f3f3f3f3f3f3f3f;
const int maxn=1000000+10;
const ll mod=1e9+7;
vector<ll>vec[maxn];
ll n;
ll dp[maxn];
ll num[maxn];
ll vis[maxn];
ll vis2[maxn];
ll value[maxn];
ll flag,point;
void init(){//清空数组 
    memset(dp,0,sizeof(dp));
    memset(num,0,sizeof(num));
    memset(vis,0,sizeof(vis));
    memset(vis2,0,sizeof(vis2));
    memset(value,0,sizeof(value));
}
ll dfs(ll x){//dfs找所有的值,num表示以这个点为根节点下面有几个点(包括这个点,dp用),value是这个点到所有子树的距离之和
    vis[x]=flag;
    ll sum=0;
    for(ll i=0;i<vec[x].size();i++){
        ll y=vec[x][i];
        if(!vis[y]){
            int q=dfs(y);
            sum+=q;
            value[x]+=value[y]+q;
        }
    }
    num[x]=sum+1;
    return sum+1;
}
void Dp(ll x){//dp数组表示当前这个点到其他点的距离之和 
    vis2[x]=flag;
    for(ll i=0;i<vec[x].size();i++){
        ll y=vec[x][i];
        if(!vis2[y]){
            dp[y]=dp[x]-num[y]+point-num[y];//换根dp方程 
            Dp(y);
        }
    }
}
void solve(){
    for(int i=1;i<=n;i++) vec[i].clear();
    init();
    ll u,v;
    for(ll i=0;i<n-2;i++){
        scanf("%lld %lld",&u,&v);
        vec[u].push_back(v);
        vec[v].push_back(u);
    }
    //使用flag的值对第一颗子树和第二棵子树进行区分,下同
    flag=1; dfs(1); dp[1]=value[1]; flag++;
    for(ll i=1;i<=n;i++)
        if(!vis[i]){
            dfs(i);
            dp[i]=value[i];
            break;
        }

    //point是当前两个子树之一的点的个数
    point=0;
    for(int i=1;i<=n;i++) 
        if(vis[i]==1) 
            point++;
    flag=1;
    Dp(1); 
    flag++;
    point=n-point;
    for(ll i=1;i<=n;i++)
        if(!vis2[i]){
            Dp(i);
            break;
        }
    //找到两个树上到其他点距离最小的点 
    ll v1=INF,v2=INF,p1,p2;
    for(ll i=1;i<=n;i++){
        if(vis2[i]==1 && dp[i]<=v1){
            p1=i;
            v1=dp[i];
        }
        if(vis2[i]==2 && dp[i]<=v2){
            p2=i;
            v2=dp[i];
        }
    }
    vec[p1].push_back(p2);
    vec[p2].push_back(p1);

    //重新进行dfs 和 dp 计算出这棵大树的所有dp值
    init(); flag=1; dfs(1);
    dp[1]=value[1]; point=n; Dp(1);
    //a->b  b->a计算两次,所以 /2
    ll res=0;
    for(ll i=1;i<=n;i++) res+=dp[i];
    printf("%lld\n",res/2ll);
}

int main()
{
    while(~scanf("%lld",&n))
        solve();
    return 0;
}