题目描述
给定一棵大小为 的树,让你断一条边,使得树上距离最大值最小。
正解
先求出树直径的两个端点 ,,最后断的边肯定是直径上的边。
考虑枚举这条边,然后快速统计答案。
先可以预处理出直径这一条链,然后求出链上前缀的直径和后缀的直径就可以 统计答案了。
断掉直径上的边后,两边联通块的直径肯定分别是 (考虑树直径的性质,反证法可以证明)。
那么好办,对于直径上的每一个点,求出它挂出去的链的最远距离即可。
时间复杂度 。
出题人不给大样例还得自己打对拍,差评。
#include <bits/stdc++.h> #define N 1000005 #define inf 1000000005 using namespace std; int n, rt1, rt2; int head[N], nex[N << 1], to[N << 1], eVal[N << 1], ecnt; int mDist[N], f[N], g[N]; bool inList[N]; vector<int> arr; inline int read() { int x = 0; char ch = getchar(); while(!isdigit(ch)) ch = getchar(); while(isdigit(ch)) x = x * 10 + ch - '0', ch = getchar(); return x; } inline void addE(int u, int v, int w) { to[++ecnt] = v, eVal[ecnt] = w; nex[ecnt] = head[u], head[u] = ecnt; } int que[N], dist[N], pre[N]; void getr(int s, int &rt) { rt = 0, dist[0] = -1; int hd = 0, tl = -1, u, v; for(int i = 1; i <= n; ++i) dist[i] = inf; dist[s] = 0, que[++tl] = s; while(hd <= tl) { u = que[hd++]; if(dist[u] > dist[rt]) rt = u; for(int i = head[u]; i; i = nex[i]) { v = to[i]; if(dist[v] > dist[u] + eVal[i]) { dist[v] = dist[u] + eVal[i]; pre[v] = u, que[++tl] = v; } } } } void getList() { getr(1, rt1); getr(rt1, rt2); int u = rt2; inList[u] = true; arr.push_back(u); while(u != rt1) { u = pre[u]; inList[u] = true; arr.push_back(u); } reverse(arr.begin(), arr.end()); } void getDist(int s) { int hd = 0, tl = -1, u, v; que[++tl] = s; while(hd <= tl) { u = que[hd++]; mDist[s] = max(mDist[s], dist[u]); for(int i = head[u]; i; i = nex[i]) { v = to[i]; if(dist[v] > dist[u] + eVal[i]) { dist[v] = dist[u] + eVal[i]; que[++tl] = v; } } } } int main() { n = read(); for(int i = 1, u, v, w; i < n; ++i) { u = read(), v = read(), w = read(); addE(u, v, w), addE(v, u, w); } getList(); for(int i = 1; i <= n; ++i) dist[i] = inf; for(auto u : arr) dist[u] = 0; for(auto u : arr) getDist(u); getr(rt1, rt2); // 重新得到距离 for(int i = 0, u; i < arr.size(); ++i) { u = arr[i]; f[i] = dist[u] + mDist[u]; if(i) f[i] = max(f[i], f[i - 1]); } for(int i = arr.size() - 1, u; ~i; --i) { u = arr[i]; g[i] = dist[rt2] - dist[u] + mDist[u]; if(i != arr.size() - 1) g[i] = max(g[i], g[i + 1]); } int ans = inf; for(int i = 0; i < arr.size() - 1; ++i) ans = min(ans, max(f[i], g[i + 1])); printf("%d\n", ans); return 0; }