Question
给一棵树,树上有一些关键节点,要求你选 个点,使得关键节点到这些点中距离的最小值的最大值最小,求这个值。
Solution
显然这道题是可以二分的,二分的好处在于我们将原问题转化为了:
对于一个树,选中最少的节点,使得任意关键节点到选中节点的最小距离 ,请问需要选中多少个节点?
这样一个树上最小点覆盖问题。
以下两种算法都是 ,我觉得第一种算法比较好想。
贪心:
我们以 号节点为根节点,贪心的从最深的叶子节点开始选中他的 级祖先(这里我们需要bfs来得到一个深度单调递增的单调栈),该点显然必须选中,然后对已覆盖的节点进行删除操作。
我思考了很久不明白如何写删除会比较合适,其实我们染色即可,一开始所有点的 , 对于选中的点 染色为 ,然后dfs暴力染色即可。
简单的证明:
反证法:对于一个最深的叶子节点,我们选择他的 级祖先最优,那么比如说我们可以构造出一条单链,选中他的 级祖先不如选中他的 级祖先来的更优秀。贪心+DP:
设:子树上未被覆盖的关键节点的最远距离。
子树上已选择的最近节点的距离。
已经选中的节点数目
初始化:
转移方程式:
会有以下三种情况:
①. ,此时有两种情况,子树已经有未被覆盖的关键节点则;子树不含有未被覆盖的节点,当前节点为未被覆盖的节点则
②. 说明选中的节点可以覆盖子树,
③. 说明当前节点必须要选,如果不选则子树无法完全覆盖了,。
注意点:特判根节点 。
仔细观察上述式子我们会发现 和 只有一个是正确的,且 时,子树才被完全覆盖,我们在dfs回溯到根节点的时候没有对根节点是否被完全覆盖进行判断,所以需要特判一下根节点。
Code1
#include <bits/stdc++.h> #define fi first #define se second #define mp make_pair #define pb push_back #define ls (p << 1) #define rs (ls | 1) #define tm ((tl + tr) >> 1) #define lowbit(x) ((x) & -(x)) using namespace std; using ll = long long; using ull = unsigned long long; using pii = pair<int, int>; constexpr double eps = 1e-8; constexpr int NINF = 0xc0c0c0c0; constexpr int INF = 0x3f3f3f3f; constexpr ll LNINF = 0xc0c0c0c0c0c0c0c0; constexpr ll LINF = 0x3f3f3f3f3f3f3f3f; constexpr ll mod = 1e9 + 7; constexpr ll N = 3e5 + 5; int n, m, key[N], f[N], vis[N], q[N], tail, head; vector<int> G[N]; void bfs() { head = 1; tail = 0; q[++tail] = 1; while (head <= tail) { int u = q[head++]; for (auto v : G[u]) { if (!f[v]) { f[v] = u; q[++tail] = v; } } } } void update(int u) { if (!vis[u]) return; for (auto v : G[u]) { if (vis[v] < vis[u] - 1) { vis[v] = vis[u] - 1; update(v); } } } bool check(int mid) { memset(vis, -1, sizeof vis); int ans = 0; for (int i = tail; i >= 1; i--) { if (vis[q[i]] == -1 && key[q[i]]) { ans++; int j = q[i]; for (int k = 1; k <= mid; k++) j = f[j]; vis[j] = mid; update(j); } } return ans <= m; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); cin >> n >> m; for (int i = 1; i <= n; i++) cin >> key[i]; for (int i = 1; i < n; i++) { int u, v; cin >> u >> v; G[u].push_back(v); G[v].push_back(u); } bfs(); int L = 0, R = n, mid, ans = n; while (L < R) { mid = L + (R - L) / 2; if (check(mid)) { R = ans = mid; } else { L = mid + 1; } } cout << ans << '\n'; return 0; }
Code2
#include <bits/stdc++.h> #define fi first #define se second #define mp make_pair #define pb push_back #define ls (p << 1) #define rs (ls | 1) #define tm ((tl + tr) >> 1) #define lowbit(x) ((x) & -(x)) using namespace std; using ll = long long; using ull = unsigned long long; using pii = pair<int, int>; constexpr double eps = 1e-8; constexpr int NINF = 0xc0c0c0c0; constexpr int INF = 0x3f3f3f3f; constexpr ll LNINF = 0xc0c0c0c0c0c0c0c0; constexpr ll LINF = 0x3f3f3f3f3f3f3f3f; constexpr ll mod = 1e9 + 7; constexpr ll N = 3e5 + 5; int n, m, key[N], cnt, f[N], g[N]; vector<int> G[N]; void dfs(int u, int p, int mid) { f[u] = NINF, g[u] = INF; for (auto v : G[u]) { if (v == p) continue; dfs(v, u, mid); f[u] = max(f[u], f[v] + 1); g[u] = min(g[u], g[v] + 1); } if (key[u] && g[u] > mid) f[u] = max(f[u], 0); if (f[u] + g[u] <= mid) f[u] = NINF; if (f[u] == mid) f[u] = NINF, g[u] = 0, cnt++; } bool check(int mid) { cnt = 0; dfs(1, -1, mid); if (f[1] >= 0) cnt++; return cnt <= m; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); cin >> n >> m; for (int i = 1; i <= n; i++) cin >> key[i]; for (int i = 1; i < n; i++) { int u, v; cin >> u >> v; G[u].push_back(v); G[v].push_back(u); } int L = 0, R = n, mid, ans = n; while (L < R) { mid = L + (R - L) / 2; if (check(mid)) { R = ans = mid; } else { L = mid + 1; } } cout << ans << '\n'; return 0; }