Question

给一棵树,树上有一些关键节点,要求你选 个点,使得关键节点到这些点中距离的最小值的最大值最小,求这个值。

Solution

显然这道题是可以二分的,二分的好处在于我们将原问题转化为了:
对于一个树,选中最少的节点,使得任意关键节点到选中节点的最小距离 ,请问需要选中多少个节点?
这样一个树上最小点覆盖问题。
以下两种算法都是 ,我觉得第一种算法比较好想。

  1. 贪心:
    我们以 号节点为根节点,贪心的从最深的叶子节点开始选中他的 级祖先(这里我们需要bfs来得到一个深度单调递增的单调栈),该点显然必须选中,然后对已覆盖的节点进行删除操作。
    我思考了很久不明白如何写删除会比较合适,其实我们染色即可,一开始所有点的 , 对于选中的点 染色为 ,然后dfs暴力染色即可。
    简单的证明:
    反证法:对于一个最深的叶子节点,我们选择他的 级祖先最优,那么比如说我们可以构造出一条单链,选中他的 级祖先不如选中他的 级祖先来的更优秀。

  2. 贪心+DP:
    设:

    子树上未被覆盖的关键节点的最远距离。

    子树上已选择的最近节点的距离。

    已经选中的节点数目
    初始化:



    转移方程式:



    会有以下三种情况:
    ①. ,此时有两种情况,子树已经有未被覆盖的关键节点则;子树不含有未被覆盖的节点,当前节点为未被覆盖的节点则
    ②. 说明选中的节点可以覆盖子树,
    ③. 说明当前节点必须要选,如果不选则子树无法完全覆盖了,
    注意点:特判根节点
    仔细观察上述式子我们会发现 只有一个是正确的,且 时,子树才被完全覆盖,我们在dfs回溯到根节点的时候没有对根节点是否被完全覆盖进行判断,所以需要特判一下根节点。

Code1

#include <bits/stdc++.h>

#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define ls (p << 1)
#define rs (ls | 1)
#define tm ((tl + tr) >> 1)
#define lowbit(x) ((x) & -(x))

using namespace std;
using ll = long long;
using ull = unsigned long long;
using pii = pair<int, int>;

constexpr double eps = 1e-8;
constexpr int NINF = 0xc0c0c0c0;
constexpr int INF = 0x3f3f3f3f;
constexpr ll LNINF = 0xc0c0c0c0c0c0c0c0;
constexpr ll LINF = 0x3f3f3f3f3f3f3f3f;
constexpr ll mod = 1e9 + 7;
constexpr ll N = 3e5 + 5;

int n, m, key[N], f[N], vis[N], q[N], tail, head;
vector<int> G[N];

void bfs() {
    head = 1;
    tail = 0;
    q[++tail] = 1;
    while (head <= tail) {
        int u = q[head++];
        for (auto v : G[u]) {
            if (!f[v]) {
                f[v] = u;
                q[++tail] = v;
            }
        }
    }
}

void update(int u) {
    if (!vis[u]) return;
    for (auto v : G[u]) {
        if (vis[v] < vis[u] - 1) {
            vis[v] = vis[u] - 1;
            update(v);
        }
    }
}

bool check(int mid) {
    memset(vis, -1, sizeof vis);
    int ans = 0;
    for (int i = tail; i >= 1; i--) {
        if (vis[q[i]] == -1 && key[q[i]]) {
            ans++;
            int j = q[i];
            for (int k = 1; k <= mid; k++) j = f[j];
            vis[j] = mid;
            update(j);
        }
    }
    return ans <= m;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> key[i];
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    bfs();
    int L = 0, R = n, mid, ans = n;
    while (L < R) {
        mid = L + (R - L) / 2;
        if (check(mid)) {
            R = ans = mid;
        } else {
            L = mid + 1;
        }
    }
    cout << ans << '\n';

    return 0;
}

Code2

#include <bits/stdc++.h>

#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define ls (p << 1)
#define rs (ls | 1)
#define tm ((tl + tr) >> 1)
#define lowbit(x) ((x) & -(x))

using namespace std;
using ll = long long;
using ull = unsigned long long;
using pii = pair<int, int>;

constexpr double eps = 1e-8;
constexpr int NINF = 0xc0c0c0c0;
constexpr int INF = 0x3f3f3f3f;
constexpr ll LNINF = 0xc0c0c0c0c0c0c0c0;
constexpr ll LINF = 0x3f3f3f3f3f3f3f3f;
constexpr ll mod = 1e9 + 7;
constexpr ll N = 3e5 + 5;

int n, m, key[N], cnt, f[N], g[N];
vector<int> G[N];

void dfs(int u, int p, int mid) {
    f[u] = NINF, g[u] = INF;
    for (auto v : G[u]) {
        if (v == p) continue;
        dfs(v, u, mid);
        f[u] = max(f[u], f[v] + 1);
        g[u] = min(g[u], g[v] + 1);
    }
    if (key[u] && g[u] > mid) f[u] = max(f[u], 0);
    if (f[u] + g[u] <= mid) f[u] = NINF;
    if (f[u] == mid) f[u] = NINF, g[u] = 0, cnt++;
}

bool check(int mid) {
    cnt = 0;
    dfs(1, -1, mid);
    if (f[1] >= 0) cnt++;
    return cnt <= m;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> key[i];
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    int L = 0, R = n, mid, ans = n;
    while (L < R) {
        mid = L + (R - L) / 2;
        if (check(mid)) {
            R = ans = mid;
        } else {
            L = mid + 1;
        }
    }
    cout << ans << '\n';

    return 0;
}