题意

给一颗树,删除一条边再加一条边,使它仍为一颗树且任意两点间的距离的最大值最小。

题目数据范围描述有问题,n为1或重建不能使任意两点距离最大值变小,可以输出任意答案。

分析

删除一条边后会使它变成两颗树,两棵树的直径的中点相连一定是使距离最小的

红色的边为删除重建的边

在树上dp维护每个子树的最大直径\(h[x]\),和去除这个子树后的树的最大直径\(t[x]\),u为x的父亲,删除u-x这条边并重建后的树的最大直径为

\[\max\{\frac{h[x]+1}{2}+\frac{t[x]+1}{2}+1,h[x],t[x]\}\]

\(g[u]\)为以\(u\)为根的子树中\(u\)能到达的最远距离

\(p[u]\)为去除以\(u\)为根的子树后\(u\)能到达的最远距离

自底向上

x为u的孩子,\(mx1\),\(mx2\)分别为\(g[x]\)的最大值和次大值

  • \(g[u]=max(g[x]+1)\)
  • \(h[u]=max\{h[x],g[u],mx1+mx2+2\}​\)

自顶向下

k为x的兄弟,\(mx1​\),\(mx2​\)分别为\(g[k]\)的最大值和次大值

  • \(p[x]=max(p[u]+1,g[k]+2)​\)
  • \(t[x]=max \{p[u],h[k],p[u]+g[k]+1,mx1+mx2+2 \}​\)

然后bfs找重建的边
实现细节很多,我写的比较乱,建议自己根据dp式子模拟一下

Code

#include<bits/stdc++.h>
#define fi first
#define se second
#define bug cout<<"--------------"<<endl
using namespace std;
typedef long long ll;
const double PI=acos(-1.0);
const double eps=1e-6;
const int inf=1e9;
const ll llf=1e18;
const int mod=1e9+7;
const int maxn=3e5+10;
int n;
vector<int>f[maxn];
typedef pair<int,int> pii;
pii e[maxn];
int g[maxn],h[maxn],p[maxn],t[maxn];
int ans=inf;
pii ans1,ans2;
void dfs1(int u,int fa){
    int mx1=-inf,mx2=-inf;
    int len=f[u].size();
    int po=len;
    for(int i=0;i<len;i++){
        int x=f[u][i];
        if(x==fa){
            continue;
        }
        dfs1(x,u);
        g[u]=max(g[x]+1,g[u]);
        if(g[x]>mx1){
            mx2=mx1;
            mx1=g[x];
        }else if(g[x]>mx2){
            mx2=g[x];
        }
        h[u]=max(h[u],h[x]);
    }
    h[u]=max(h[u],g[u]);
    h[u]=max(mx1+mx2+2,h[u]);
}
int pre[maxn],suf[maxn];
int pr[maxn],sf[maxn];
void dfs2(int u,int fa){
    int len=f[u].size();
    vector<int>q;
    q.push_back(0);
    for(int i=0;i<len+5;i++) pre[i]=suf[i]=pr[i]=sf[i]=-inf;
    for(int i=0;i<len;i++){
        int x=f[u][i];
        if(x!=fa) q.push_back(x);
    }
    len=q.size()-1;
    for(int i=1;i<=len;i++){
        int x=q[i];
        t[x]=max(t[x],t[u]);
        pre[i]=max(pre[i-1],g[x]);
        pr[i]=max(pr[i-1],h[x]);
    }
    for(int i=len;i>=1;i--){
        int x=q[i];
        suf[i]=max(suf[i+1],g[x]);
        sf[i]=max(sf[i+1],h[x]);        
    }
    for(int i=1;i<=len;i++){
        int x=q[i];
        p[x]=max(p[x],p[u]+1);
        t[x]=max(p[u],t[x]);
    }
    for(int i=2;i<=len;i++){
        int x=q[i];
        t[x]=max(p[u]+1+pre[i-1],t[x]);
        t[x]=max(pr[i-1],t[x]);
        p[x]=max(pre[i-1]+2,p[x]);
    }
    for(int i=1;i<len;i++){
        int x=q[i];
        t[x]=max(p[u]+1+suf[i+1],t[x]);
        t[x]=max(sf[i+1],t[x]);
        p[x]=max(suf[i+1]+2,p[x]);      
    }
    for(int i=2;i<len;i++){
        int x=q[i];
        t[x]=max(t[x],pre[i-1]+suf[i+1]+2);
    }
    int mx1=-inf,mx2=-inf;
    for(int i=1;i<=len;i++){
        int x=q[i];
        t[x]=max(mx1+mx2+2,t[x]);
        if(g[x]>mx1){
            mx2=mx1;
            mx1=g[x];
        }else if(g[x]>mx2){
            mx2=g[x];
        }
    }
    mx1=mx2=-inf;
    for(int i=len;i>=1;i--){
        int x=q[i];
        t[x]=max(mx1+mx2+2,t[x]);
        if(g[x]>mx1){
            mx2=mx1;
            mx1=g[x];
        }else if(g[x]>mx2){
            mx2=g[x];
        }
    }
    for(int i=1;i<=len;i++){
        int x=q[i];
        int dis=max(max(t[x],h[x]),(t[x]+1)/2+(h[x]+1)/2+1);
        if(dis<ans){
            ans=dis;
            ans1=pii(x,u);
        }
    }
    for(int i=0;i<(int)f[u].size();i++){
        int x=f[u][i];
        if(x==fa) continue;
        dfs2(x,u);
    }
}
int pe[maxn],vis[maxn];
queue<int>q;
int bfs(int fa){
    int ret=fa;
    memset(vis,0,sizeof(vis));
    memset(pe,0,sizeof(pe));
    q.push(fa);
    vis[fa]=1;
    while(!q.empty()){
        int u=q.front();
        q.pop();
        ret=u;
        int len=f[u].size();
        for(int i=0;i<len;i++){
            if(!vis[f[u][i]]){
                q.push(f[u][i]);
                pe[f[u][i]]=u;
                vis[f[u][i]]=1; 
            }
        }
    }
    return ret;
}
int fq[maxn],tot;
void dfs(int u,int s){
    if(u==0) return;
    fq[++tot]=u;
    dfs(pe[u],s);
}
int find(int x){
    tot=0;
    int s=bfs(x);
    int t=bfs(s);
    dfs(t,s);
    return fq[(tot+1)/2];
}
void work(){
    for(int i=1;i<=n;i++){
        f[i].clear();
    }
    for(int i=1;i<n;i++){
        int a=e[i].fi,b=e[i].se;
        if(a==ans1.fi&&b==ans1.se) continue;
        if(b==ans1.fi&&a==ans1.se) continue;
        f[a].push_back(b);
        f[b].push_back(a);
    }
    ans2.fi=find(ans1.fi);
    ans2.se=find(ans1.se);
    cout<<ans<<endl;
    cout<<ans1.fi<<" "<<ans1.se<<endl;
    cout<<ans2.fi<<" "<<ans2.se<<endl;
}
int main(){
    ios::sync_with_stdio(false);
    //freopen("in","r",stdin);
    cin>>n;
    for(int i=1,a,b;i<n;i++){
        cin>>a>>b;
        f[a].push_back(b);
        f[b].push_back(a);
        e[i]=pii(a,b);
    }
    dfs1(1,0);
    dfs2(1,0);
    work();
    return 0;
}