写两种做法吧。
第一种是直接重心性质,第二种是树形dp+换根。
一、利用重心性质:
下面是树的重心的性质:
1.树中所有点到某个点的距离和中,到重心的距离和是最小的,如果有两个重心,他们的距离和一样。
2.把两棵树通过一条边相连,新的树的重心在原来两棵树重心的连线上。
3.一棵树添加或者删除一个节点,树的重心最多只移动一条边的位置。
4.一棵树最多有两个重心,且相邻。
摘自百度百科
这题就考察到了树的重心的第一个性质
即树中所有点到重心的距离和最小,而此题所求的最小的w,就是所有点到重心的距离之和。
所以考虑dfs树形dp求出重心,然后bfs求出每个点距离重心的距离进行累加求和即可。
#include <bits/stdc++.h> #define inf 100000000 #define x first #define y second #define pp pair<ll, ll> using namespace std; typedef long long ll; const ll N = 2e6+5; ll head[N], top = 0; ll n; ll son[N], vis[N]; ll ans, res, point; struct Edge { ll v, next; } edge[N * 2]; void init() { memset(head, -1, sizeof(head)); top = 0; memset(son, 0, sizeof(son)); ans = inf; } void addedge(ll u, ll v) { edge[top].v = v; edge[top].next = head[u]; head[u] = top++; } void dfs(ll u, ll fa) { son[u] = 1; ll Max = 0; for (ll i = head[u]; i != -1; i = edge[i].next) { ll v = edge[i].v; if (v == fa) continue; dfs(v, u); son[u] += son[v]; Max = max(Max, son[v]); } ll tmp = max(Max, n - son[u]); if (tmp < ans || tmp == ans && u < point) { ans = tmp; point = u; } } void bfs() { queue<pp> q; vis[point] = 1; q.push(make_pair(point, 0)); while (!q.empty()) { pp now = q.front(); q.pop(); for (ll i = head[now.x]; i != -1; i = edge[i].next) { ll to = edge[i].v; if (!vis[to]) { vis[to] = 1; q.push(make_pair(to, now.y + 1)); res += now.y + 1; } } } } int main() { init(); scanf("%lld", &n); for (ll i = 1; i < n; i++) { ll u, v; scanf("%lld%lld", &u, &v); addedge(u, v); addedge(v, u); } dfs(1, -1); bfs(); printf("%lld\n", res); return 0; }
如果不知道重心的这个性质也没关系,直接树形dp+换根也能做。
首先随便取一个点作为根dfs一下,求出每个子树的大小,对于这次dfs的W的值就是 sum+=size[v]-1
即计算每个父节点与子节点连接的边被多少点经过。
然后考虑换根 怎么换根呢
假设父节点为root,子节点为son
对于刚才计算的则有
什么意思呢? 考虑从根从父节点root到子节点son
那么对于son内部的所有点(包括son自己)要走的路就少了一条,就是减少了
对于其他点还有 个
他们要多走一步
所以在此dfs一下,更新每个点作为根的最小值就是答案了。
#include<bits/stdc++.h> using namespace std; const int maxn=1e6+50; vector<int> e[maxn]; int siz[maxn]; int ans,sum,n; void dfs(int x,int f){ siz[x]=1; for(auto it:e[x]){ if(it==f) continue; dfs(it,x); siz[x]+=siz[it]; } sum+=siz[x]-1; } void dfs_(int x,int f){ ans=min(ans,sum); for(auto it:e[x]){ if(it==f) continue; sum+=n-2*siz[it]; dfs_(it,x); sum-=n-2*siz[it]; } } int main(){ ios::sync_with_stdio(0);cin>>n; for(int i=1;i<n;i++){ int x,y;cin>>x>>y; e[x].push_back(y); e[y].push_back(x); } dfs(1,-1); ans=sum; dfs_(1,-1); cout<<ans<<endl; return 0; }