题目链接
大意:给你两个不连通的树,让你加一条边,使得两个树联通,并且使得 i n i + 1 n d i s ( i , j ) \sum_i^n\sum_{i+1}^ndis(i,j) ini+1ndis(i,j)最小。
思路:首先我们需要一个dfs将两个联通块的点求出来,然后再用一个dfs求出以每个点的子节点个数,和每个点的子节点的dis之和,用两个数组,cnt数组记录子节点个数,dp数组记录子节点的dis之和,设f为父节点,点集s为f的子节点集合,那么
d p [ f ] = i S d p [ i ] + c n t [ i ] dp[f]=\sum_{i\in S}dp[i]+cnt[i] dp[f]=iSdp[i]+cnt[i]。求完之后我们就需要连边了,我们想一下,显然我们需要在两个联通块中找到两个最优的点,
用贪心思想:找到一个点S,使得以S为根的时候dp[S]最小,那么这个就是我们要找的点,这个点就是重心,我们可以使用换根来解决。找到两个点连起来完事。
中间的贡献我们算一算即可,新连的边显然要使用 l . s i z e ( ) r . s i z e ( ) l.size()*r.size() l.size()r.size(),指两个联通块的大小,每个联通块内的边的使用次数也可以在一个dfs内求出,然后最后一个部分就是两个联通块内分别取一点的新贡献了,答案显然是 l . s i z e ( ) d p [ R ] + r . s i z e ( ) d p [ L ] l.size()*dp[R]+r.size()*dp[L] l.size()dp[R]+r.size()dp[L],,三个加起来就行了。
细节见代码:

#include<bits/stdc++.h>

#define LL long long
#define fi first
#define se second
#define mp make_pair
#define pb push_back

using namespace std;

LL gcd(LL a,LL b){return b?gcd(b,a%b):a;}
LL lcm(LL a,LL b){return a/gcd(a,b)*b;}
LL powmod(LL a,LL b,LL MOD){LL ans=1;while(b){if(b%2)ans=ans*a%MOD;a=a*a%MOD;b/=2;}return ans;}
const int N = 2e5 +11;
vector<int>v[N];
vector<int>l,r;
int n,vis[N];
void dfs(int now,int pre){
    l.pb(now);
    for(auto k:v[now]){
        if(k==pre)continue;
        dfs(k,now);
    }
}
LL a[N],cnt[N],dp[N];
LL ans;
void pa(int now,int pre,int len){
    cnt[now]=1;
    for(auto k:v[now]){
        if(k==pre)continue;
        pa(k,now,len);
        ans+=1ll*(len-cnt[k])*cnt[k];
        dp[now]+=dp[k]+cnt[k];
        cnt[now]+=cnt[k];
    }
}
LL va[3];
void chos(int now,int pre,int bb,int len){
    va[bb]=min(va[bb],dp[now]);
    for(auto k:v[now]){
        if(k==pre)continue;
        dp[k]=dp[k]+dp[now]-dp[k]-cnt[k]+len-cnt[k];
        int prx=cnt[now];
        int pry=cnt[k];
        cnt[now]=cnt[now]-cnt[k];
        cnt[k]+=cnt[now];
        chos(k,now,bb,len);
        cnt[k]=pry;
        cnt[now]=prx;
    }
}
int main(){
    ios::sync_with_stdio(false);
    cin>>n;
    va[1]=va[2]=1e18;
    for(int i=1;i<=n-2;i++){
        int s,t;
        cin>>s>>t;
        v[s].pb(t);
        v[t].pb(s);
    }    
    dfs(1,0);
    for(auto k:l)vis[k]=1;
    for(int i=1;i<=n;i++)if(!vis[i])r.pb(i);
    ans=1ll*l.size()*r.size();
    pa(l[0],0,l.size());
    pa(r[0],0,r.size());
    chos(l[0],0,1,l.size());
    chos(r[0],0,2,r.size());
    ans+=1ll*va[1]*r.size()+va[2]*l.size();
    cout<<ans<<endl;
    return 0;
}