题目
算法标签: 图论, 树的直径, 树形 D P DP DP
思路
题目大意就是添加 1 / 2 1 / 2 1/2条边 使得从起点开始走访问所有的点, 使得总的路径长度最小, 因为原图是树, 因此最初的巡回距离是 2 ⋅ ( n − 1 ) 2 \cdot (n - 1) 2⋅(n−1), 将问题分为两部分, 首先是添加一条边

最优策略就是在树的直径的两端添加一条边, 如上图, 树的直径是粉色线段设为 d d d, 如果当前巡逻到最下面的点如果不添加边需要 d d d的长度, 如果添加一条边, 路径长度变为 1 1 1, 减少的距离就是 d − 1 d - 1 d−1
如果可以建造两条道路, 可以找到树的最长路径和次长路径, 然后分别在两端设置路径
那么如何计算树的次长路径呢?
可以在计算完最长路径后, 将路径上的边权标志为 − 1 -1 −1, 再进行统计直径, 因为边权是负数不会产生贡献, 那么次长路径也计算出来了
为什么第二次求树的直径不能两次 d f s dfs dfs直接求?
因为边权变为负数, 直接求是错误的, 只能使用树形 D P DP DP求解
因为每个点被访问一次, 算法时间复杂度 O ( n ) O(n) O(n)
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
using namespace std;
const int N = 1e6 + 10, M = N << 1;
int n, k;
int head[N], ed[M], ne[M], w[M], idx;
int dist[N], len1, len2;
int from[N];
bool vis[N];
void add(int u, int v, int val) {
ed[idx] = v, ne[idx] = head[u], w[idx] = val, head[u] = idx++;
}
void dfs1(int u, int fa) {
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == fa) continue;
dist[v] = dist[u] + w[i];
dfs1(v, u);
}
}
void calc() {
// 第一次DFS找最远点
memset(dist, 0, sizeof dist);
dfs1(1, -1);
int max_u = 1;
for (int i = 1; i <= n; ++i) {
if (dist[i] > dist[max_u]) max_u = i;
}
// 第二次DFS找直径
memset(dist, 0, sizeof dist);
dfs1(max_u, -1);
int max_v = max_u;
for (int i = 1; i <= n; ++i) {
if (dist[i] > dist[max_v]) max_v = i;
}
len1 = dist[max_v];
if (k == 1) return;
// 记录直径路径
memset(vis, 0, sizeof vis);
queue<int> q;
q.push(max_v);
vis[max_v] = true;
from[max_v] = -1;
while (!q.empty()) {
int u = q.front();
q.pop();
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (vis[v] || dist[v] != dist[u] - w[i]) continue;
vis[v] = true;
from[v] = u;
q.push(v);
}
}
// 将直径上的边权设为-1
for (int u = max_u; u != -1; u = from[u]) {
for (int i = head[u]; ~i; i = ne[i]) {
if (ed[i] == from[u]) {
w[i] = -1;
w[i ^ 1] = -1;
break;
}
}
}
}
void dp(int u, int fa) {
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == fa) continue;
dp(v, u);
len2 = max(len2, dist[u] + dist[v] + w[i]);
dist[u] = max(dist[u], dist[v] + w[i]);
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
memset(head, -1, sizeof head);
cin >> n >> k;
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
add(u, v, 1);
add(v, u, 1);
}
calc();
if (k == 1) {
cout << 2 * (n - 1) - (len1 - 1) << endl;
return 0;
}
memset(dist, 0, sizeof dist);
dp(1, -1);
cout << 2 * (n - 1) - (len1 - 1) - (len2 - 1) << endl;
return 0;
}


京公网安备 11010502036488号