Question
给一棵树,树上有一些关键节点,要求你选 个点,使得关键节点到这些点中距离的最小值的最大值最小,求这个值。
Solution
显然这道题是可以二分的,二分的好处在于我们将原问题转化为了:
对于一个树,选中最少的节点,使得任意关键节点到选中节点的最小距离 ,请问需要选中多少个节点?
这样一个树上最小点覆盖问题。
以下两种算法都是 ,我觉得第一种算法比较好想。
贪心:
我们以号节点为根节点,贪心的从最深的叶子节点开始选中他的
级祖先(这里我们需要bfs来得到一个深度单调递增的单调栈),该点显然必须选中,然后对已覆盖的节点进行删除操作。
我思考了很久不明白如何写删除会比较合适,其实我们染色即可,一开始所有点的, 对于选中的点
染色为
,然后dfs暴力染色即可。
简单的证明:
反证法:对于一个最深的叶子节点,我们选择他的级祖先最优,那么比如说我们可以构造出一条单链,选中他的
级祖先不如选中他的
级祖先来的更优秀。
贪心+DP:
设:子树上未被覆盖的关键节点的最远距离。
子树上已选择的最近节点的距离。
已经选中的节点数目
初始化:
转移方程式:
会有以下三种情况:
①.,此时有两种情况,子树已经有未被覆盖的关键节点则
;子树不含有未被覆盖的节点,当前节点为未被覆盖的节点则
②.说明选中的节点可以覆盖子树,
③.说明当前节点必须要选,如果不选则子树无法完全覆盖了,
。
注意点:特判根节点。
仔细观察上述式子我们会发现和
只有一个是正确的,且
时,子树才被完全覆盖,我们在dfs回溯到根节点的时候没有对根节点是否被完全覆盖进行判断,所以需要特判一下根节点。
Code1
#include <bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define ls (p << 1)
#define rs (ls | 1)
#define tm ((tl + tr) >> 1)
#define lowbit(x) ((x) & -(x))
using namespace std;
using ll = long long;
using ull = unsigned long long;
using pii = pair<int, int>;
constexpr double eps = 1e-8;
constexpr int NINF = 0xc0c0c0c0;
constexpr int INF = 0x3f3f3f3f;
constexpr ll LNINF = 0xc0c0c0c0c0c0c0c0;
constexpr ll LINF = 0x3f3f3f3f3f3f3f3f;
constexpr ll mod = 1e9 + 7;
constexpr ll N = 3e5 + 5;
int n, m, key[N], f[N], vis[N], q[N], tail, head;
vector<int> G[N];
void bfs() {
head = 1;
tail = 0;
q[++tail] = 1;
while (head <= tail) {
int u = q[head++];
for (auto v : G[u]) {
if (!f[v]) {
f[v] = u;
q[++tail] = v;
}
}
}
}
void update(int u) {
if (!vis[u]) return;
for (auto v : G[u]) {
if (vis[v] < vis[u] - 1) {
vis[v] = vis[u] - 1;
update(v);
}
}
}
bool check(int mid) {
memset(vis, -1, sizeof vis);
int ans = 0;
for (int i = tail; i >= 1; i--) {
if (vis[q[i]] == -1 && key[q[i]]) {
ans++;
int j = q[i];
for (int k = 1; k <= mid; k++) j = f[j];
vis[j] = mid;
update(j);
}
}
return ans <= m;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> key[i];
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u);
}
bfs();
int L = 0, R = n, mid, ans = n;
while (L < R) {
mid = L + (R - L) / 2;
if (check(mid)) {
R = ans = mid;
} else {
L = mid + 1;
}
}
cout << ans << '\n';
return 0;
}Code2
#include <bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define ls (p << 1)
#define rs (ls | 1)
#define tm ((tl + tr) >> 1)
#define lowbit(x) ((x) & -(x))
using namespace std;
using ll = long long;
using ull = unsigned long long;
using pii = pair<int, int>;
constexpr double eps = 1e-8;
constexpr int NINF = 0xc0c0c0c0;
constexpr int INF = 0x3f3f3f3f;
constexpr ll LNINF = 0xc0c0c0c0c0c0c0c0;
constexpr ll LINF = 0x3f3f3f3f3f3f3f3f;
constexpr ll mod = 1e9 + 7;
constexpr ll N = 3e5 + 5;
int n, m, key[N], cnt, f[N], g[N];
vector<int> G[N];
void dfs(int u, int p, int mid) {
f[u] = NINF, g[u] = INF;
for (auto v : G[u]) {
if (v == p) continue;
dfs(v, u, mid);
f[u] = max(f[u], f[v] + 1);
g[u] = min(g[u], g[v] + 1);
}
if (key[u] && g[u] > mid) f[u] = max(f[u], 0);
if (f[u] + g[u] <= mid) f[u] = NINF;
if (f[u] == mid) f[u] = NINF, g[u] = 0, cnt++;
}
bool check(int mid) {
cnt = 0;
dfs(1, -1, mid);
if (f[1] >= 0) cnt++;
return cnt <= m;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> key[i];
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u);
}
int L = 0, R = n, mid, ans = n;
while (L < R) {
mid = L + (R - L) / 2;
if (check(mid)) {
R = ans = mid;
} else {
L = mid + 1;
}
}
cout << ans << '\n';
return 0;
}
京公网安备 11010502036488号