P2495 消耗战

题意

我们有一个树,它的根节点为 1 ,它的每条边都有相应的权值,我们每次都会产生一些点,要求如果要将这些点通过删边的方式使之不与根节点相连,求删去边权和的最小值。

思路

暴力

如何暴力做?

f[i] 表示所有的处理完这点的最小值

第一步,当我们输入我们所有的点之后,我们将所有的点都标记上,然后将 f[n] 初始化为 0 。如果我们遇到这些标记,那么我们只需要将 f[i] 设置为 +inf 随后返回即可。对于每个节点我们只需要写f[i] += min(w , f[j]) 即可求出来。

#include <iostream>
#include <vector>
using namespace std;

const int N = 3e5 + 10;
vector<pair<int, int>> E[N];
long long f[N];
bool is[N];

void dfs(int u, int fa = -1) {
    if (is[u]) {
        f[u] = 1ll << 60;
        return;
    }
    f[u] = 0;
    for (auto [x, w] : E[u]) {
        if (x == fa) continue;
        dfs(x, u);
        f[u] += min(f[x], (long long)w);
    }
}

int main() {
    int n;
    cin >> n;

    for (int i = 1; i < n; i++) {
        int x, y, w;
        cin >> x >> y >> w;
        E[x].push_back({y, w});
        E[y].push_back({x, w});
    }

    int m;
    cin >> m;

    while (m--) {
        int k;
        cin >> k;
        vector<int> v(k);
        for (int i = 0; i < k; i++) cin >> v[i], is[v[i]] = true;
        dfs(1);
        cout << f[1] << '\n';
        for (int i = 0; i < k; i++) is[v[i]] = false;
    }

    return 0;
}

优化

透过这个过程,我们不难发现,对于一个点而言,删去这个点的最小值就是从该点出发到达根中的所有路径中的最小值,并将这个最小值更新为我们的删点最小值,对于多个点的情况而言,这将不会有任何影响。

然后,我们发现,如果我们取了每个点的最小值的话,那么对于这棵树我们甚至不需要一个完整的树就可以得到答案了,只需要那些关键点就足以算出答案。到此,我们就可以着手虚树了。

首先,我们可以得到一些关键点,然后我们将这些关键点按照树的前序遍历序排好。那么我们这个顺序就会变成先根后子树,先左子树后右子树。

我们来开一个栈,栈中的元素表示从根出发一直到最右边子树的那条链。对于两个子树中的节点的话,我们需要使用 lca 算法来完成。

最首先呢,我们需要将栈中放入根,然后我们按照我们的序逐个加入点k

我们设置栈顶元素为 x,栈顶的下一个元素为 y。序为dfn[ ]lc = lca(x ,k)

他们可能有如下几种情况

  1. dfn[lc] == dfn[x]
    也就是说 kx 的子树,那么将 k 入栈即可。 alt

  2. dfn[lc] != dfn[x] 这种情况分以下几类讨论。

    1. dfn[y] < dfn[lc]
      由我们上面所说的顺序,那么此时,ylc 都是 x 的祖先,那么此时 lcy 的子树。连接 lc->x。将 x 出栈,分别入栈 lck。加点完毕。
      alt

    2. dfn[y] == dfn[lc] 即,xk 的祖先为 y。连边 y->xx 出栈,入栈 k 。加点完毕。
      alt

    3. dfn[y] > dfn[lc] 也就是说 lc 也是 y 的祖先之一。连边y->x,然后将 x 出栈。然后将 x 设为栈顶,y 设为次位。一直循环到遇到上面情况为止。
      alt

至此,我们的虚树也就建立完成了。 我们只需要遍历这颗树就可以算出答案了。

#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

// --- graph 1
const int N = 5e5 + 10;
int he[N], nxt[N * 2], v[N * 2], id = 1;
int w[N];
void add(int x, int y, long long val) {
    v[id] = y;
    w[id] = val;
    nxt[id] = he[x];
    he[x] = id++;
}

// --- base
int dfn[N], pre[N][25];
int deps[N], dpn[N];
long long mn[N];

void get_dfn(int u, int fa = -1, int deep = 1) {
    deps[deep] = u;
    dpn[u] = deep;
    dfn[u] = ++id;

    for (int i = 0; i < 25; i++) {
        if (deep - (1 << i) < 0) break;
        pre[u][i] = deps[deep - (1 << i)];
    }

    for (int i = he[u]; i != 0; i = nxt[i]) {
        int y = v[i];
        if (y == fa) continue;
        mn[y] = min(mn[u], (long long)w[i]);
        get_dfn(y, u, deep + 1);
    }
}

// --- lca

int lca(int x, int y) {
    if (dpn[x] < dpn[y]) swap(x, y);
    int kki = dpn[x] - dpn[y];
    for (int i = 0; i < 20; i++) {
        if ((kki >> i) & 1) {
            x = pre[x][i];
        }
        if (!(kki >> i)) break;
    }

    if (x == y) return y;

    for (int i = 24; i >= 0; i--) {
        if (pre[y][i] != pre[x][i]) {
            x = pre[x][i];
            y = pre[y][i];
        }
    }
    return pre[x][0];
}

// --- unreal tree

vector<int> E[N];
int stk[N], top;
void insert(int kki) {
    int x = stk[top];
    int y = stk[top - 1];
    int fa = lca(x, kki);
    if (dfn[x] == dfn[fa]) {
        stk[++top] = kki;
        return;
    }
    while (dfn[y] > dfn[fa]) {
        E[y].push_back(x);
        top--;
        x = stk[top];
        y = stk[top - 1];
    }
    if (dfn[y] == dfn[fa]) {
        E[y].push_back(x);
        stk[top] = kki;
    } else if (dfn[y] < dfn[fa]) {
        E[fa].push_back(x);
        stk[top] = fa;
        stk[++top] = kki;
    }
}

void erase(int u) {
    for (auto x : E[u]) {
        erase(x);
    }
    E[u].clear();
}

// --- calc

long long f[N];
bool is[N];

void calc(int u) {
    if (is[u]) {
        f[u] = 1ll << 60;
        return;
    }
    f[u] = 0;
    for (auto x : E[u]) {
        calc(x);
        f[u] += min(f[x], mn[x]);
    }
}

int seq[N];

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    int n;
    cin >> n;

    for (int i = 1; i < n; i++) {
        int x, y;
        long long val;
        cin >> x >> y >> val;
        add(x, y, val);
        add(y, x, val);
    }

    mn[1] = 1ll << 60;

    get_dfn(1);
    id = 0;

    int m;
    cin >> m;
    while (m--) {
        stk[0] = 1;
        int kki;
        cin >> kki;
        for (int i = 0; i < kki; i++) {
            cin >> seq[i];
            is[seq[i]] = true;
        }
        sort(seq, seq + kki, [](int x, int y) { return dfn[x] < dfn[y]; });
        stk[1] = seq[0];
        top = 1;
        for (int i = 1; i < kki; i++) insert(seq[i]);
        while (top) {
            E[stk[top - 1]].push_back(stk[top]);
            top--;
        }
        calc(1);
        cout << f[1] << '\n';
        erase(1);
        for (int i = 0; i < kki; i++) is[seq[i]] = false;
    }

    return 0;
}