题目描述

Z 国有n座城市,n − 1条双向道路,每条双向道路连接两座城市,且任意两座城市
都能通过若干条道路相互到达。
Z 国的国防部长小 Z 要在城市中驻扎军队。驻扎军队需要满足如下几个条件:
1. 一座城市可以驻扎一支军队,也可以不驻扎军队。
2. 由道路直接连接的两座城市中至少要有一座城市驻扎军队。
3. 在城市里驻扎军队会产生花费,在编号为i的城市中驻扎军队的花费是pi。
小 Z 很快就规划出了一种驻扎军队的方案,使总花费最小。但是国王又给小 Z 提出
了m个要求,每个要求规定了其中两座城市是否驻扎军队。小 Z 需要针对每个要求逐一 给出回答。具体而言,如果国王提出的第j个要求能够满足上述驻扎条件(不需要考虑 第 j 个要求之外的其它要求),则需要给出在此要求前提下驻扎军队的最小开销。如果 国王提出的第j个要求无法满足,则需要输出-1 (1 ≤ j ≤ m)。现在请你来帮助小 Z。

输入描述:

第 1 行包含两个正整数𝑛, 𝑚和一个字符串𝑡𝑦𝑝𝑒,分别表示城市数、要求数和数据类 型。𝑡𝑦𝑝𝑒是一个由大写字母 A,B 或 C 和一个数字 1,2,3 组成的字符串。它可以帮助 你获得部分分。你可能不需要用到这个参数。这个参数的含义在【数据规模与约定】中 有具体的描述。
第2行n个整数pi,表示编号i的城市中驻扎军队的花费。
接下来n − 1行,每行两个正整数u, v,表示有一条u到v的双向道路。 接下来m行,第j行四个整数a, x, b, y(a ≠ b),表示第j个要求是在城市a驻扎x支军队,
在城市b驻扎y支军队。其中,x 、 y 的取值只有 0 或 1:若 x 为 0,表示城市 a 不得驻 扎军队,若 x 为 1,表示城市 a 必须驻扎军队;若 y 为 0,表示城市 b 不得驻扎军队, 若 y 为 1,表示城市 b 必须驻扎军队。 输入文件中每一行相邻的两个数据之间均用一个空格分隔。

输出描述:

输出共m行,每行包含 1 个整数,第j行表示在满足国王第j个要求时的最小开销, 如果无法满足国王的第j个要求,则该行输出-1。

示例1

输入
5 3 C3
2 4 1 3 9
1 5
5 2
5 3
3 4
1 0 3 0
2 1 3 1
1 0 5 0
输出
12
7
-1
说明
对于第一个要求,在 4 号和 5 号城市驻扎军队时开销最小。
对于第二个要求,在 1 号、2 号、3 号城市驻扎军队时开销最小。 第三个要求是无法满足的,因为在 1 号、5 号城市都不驻扎军队就意味着由道路直接连 接的两座城市中都没有驻扎军队。

备注

对于 100%的数据,n, m ≤ 300000,1 ≤ pi ≤ 100000。
数据类型的含义:
A:城市i与城市i + 1直接相连。
B:任意城市与城市 1 的距离不超过 100(距离定义为最短路径上边的数量),即如果这 棵树以 1 号城市为根,深度不超过 100。
C:在树的形态上无特殊约束。
1:询问时保证a = 1, x = 1,即要求在城市 1 驻军。对b, y没有限制。
2:询问时保证a, b是相邻的(由一条道路直接连通)
3:在询问上无特殊约束。

解答

我们可以通过树形 在线性时间内求出一个点如果颜色为 ,那么整棵树的最小代价为 (具体的做法就是先从下往上树形 得出点 如果选 这个颜色的话整个子树中的最小代价为 ,然后从上往下 得出点 的父亲如果选 这个颜色的话,以为根 父亲子树中的最小代价为 ,具体细节不在此详细赘述)。

发现如果固定两个点 的颜色分别为 ,那么就应该对于树上 路径上(不包含 )的每一个点分别考虑是否染黑。因为 信息是可减的,所以如果那条链上的染色的方案已经确定下来了,我们容易算出总代价:考虑路径上从上往下连续的三个点 颜色为,那么 的贡献就是 ,其中 表示能够和 相临的颜色。我们把这个贡献算在 这条边上。
考虑如何确定最优的链上染色方案。对于树上从上往下的两条链,其中一条链顶端的父亲是另一条链的底端,我们要合并这两条链的信息。发现 转移只和链的两端的颜色有关,所以对于一条链只需记录它两边是否染黑即可。合并的时候枚举相邻两点的颜色,如果不全为 则合法。于是,我们考虑倍增。令表示 点向上长度为  的链,顺序为从下往上或者从上往下的 值。这样就可以通过倍增转移,询问时像查询 一样查询即可。时间复杂度
代码实现
细节较多,注意特判询问时某个点在另一个点子树中的情况。
#include <cstdio>
#include <algorithm>
using namespace std;

typedef long long llong;
const int maxn = 1e5, maxm = 2e5, logn = 16; const llong infl = 1e18 + 1e9 + 1;
int n, m, a[maxn + 3], tot, ter[maxm + 3], nxt[maxm + 3], lnk[maxn + 3];
int dep[maxn + 3], cnt, l[maxn + 3], r[maxn + 3], fa[maxn + 3][logn + 3];
llong dp1[maxn + 3][2], dp2[maxn + 3][2], f[maxn + 3][2];

inline void upd_min(llong &a, llong b) {
    a = min(a, b);
}

struct node {
    llong dp[2][2];
    node() { dp[0][0] = dp[0][1] = dp[1][0] = dp[1][1] = infl; }
    llong get_min() {
        llong ans = infl;
        for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) {
            upd_min(ans, dp[i][j]);
        }
        return ans;
    }
    friend inline node merge(const node &a, const node &b) {
        node c;
        for (int i = 0; i < 2; i++) for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 2; k++) for (int l = 0; l < 2; l++) {
                if (k || l) upd_min(c.dp[i][j], a.dp[i][k] + b.dp[l][j]);
            }
        }
        return c;
    }
} g[maxn + 3][logn + 3][2];

void add_edge(int u, int v) {
    ter[++tot] = v;
    nxt[tot] = lnk[u];
    lnk[u] = tot;
}

void dfs1(int u, int p) {
    dp1[u][1] = a[u];
    for (int e = lnk[u], v; e; e = nxt[e]) {
        if ((v = ter[e]) == p) continue;
        dfs1(v, u);
        dp1[u][0] += dp1[v][1];
        dp1[u][1] += min(dp1[v][0], dp1[v][1]);
    }
}

void dfs2(int u, int p) {
    for (int e = lnk[u], v; e; e = nxt[e]) {
        if ((v = ter[e]) == p) continue;
        dp2[v][0] = dp2[u][1] + dp1[u][0] - dp1[v][1];
        dp2[v][1] = min(dp2[u][0], dp2[u][1]) + dp1[u][1] - min(dp1[v][0], dp1[v][1]);
        dfs2(v, u);
    }
    f[u][0] = dp2[u][1];
    f[u][1] = a[u] + min(dp2[u][0], dp2[u][1]);
    for (int e = lnk[u], v; e; e = nxt[e]) {
        if ((v = ter[e]) == p) continue;
        f[u][0] += dp1[v][1];
        f[u][1] += min(dp1[v][0], dp1[v][1]);
    }
}

void dfs3(int u, int p) {
    dep[u] = dep[p] + 1, fa[u][0] = p;
    l[u] = r[u] = ++cnt;
    llong A = f[p][0] - dp1[u][1] - dp2[p][1];
    llong B = f[p][1] - min(dp1[u][0], dp1[u][1]) - min(dp2[p][0], dp2[p][1]);
    g[u][0][0].dp[0][0] = A, g[u][0][0].dp[1][1] = B;
    g[u][0][1].dp[0][0] = A, g[u][0][1].dp[1][1] = B;
    for (int i = 0, t; (t = fa[fa[u][i]][i]); i++) {
        fa[u][i + 1] = t;
        g[u][i + 1][0] = merge(g[u][i][0], g[fa[u][i]][i][0]);
        g[u][i + 1][1] = merge(g[fa[u][i]][i][1], g[u][i][1]);
    }
    for (int e = lnk[u], v; e; e = nxt[e]) {
        if ((v = ter[e]) == p) continue;
        dfs3(v, u), r[u] = r[v];
    }
}

llong solve(int u, int a, int v, int b) {
    if (dep[u] > dep[v]) swap(u, v), swap(a, b);
    node A, B;
    B.dp[b][b] = f[v][b] - (!b ? dp2[v][1] : min(dp2[v][0], dp2[v][1]));
    if (l[u] <= l[v] && l[v] <= r[u]) {
        int diff = dep[v] - dep[u] - 1;
        for (int i = 0; i <= logn; i++) {
            if (diff >> i & 1) {
                B = merge(g[v][i][1], B);
                v = fa[v][i];
            }
        }
        A.dp[a][a] = f[u][a] - (!a ? dp1[v][1] : min(dp1[v][0], dp1[v][1]));
        return merge(A, B).get_min();
    }
    A.dp[a][a] = f[u][a] - (!a ? dp2[u][1] : min(dp2[u][0], dp2[u][1]));
    int diff = dep[v] - dep[u];
    for (int i = 0; i <= logn; i++) {
        if (diff >> i & 1) {
            B = merge(g[v][i][1], B);
            v = fa[v][i];
        }
    }
    if (fa[u][0] == fa[v][0]) goto next_part;
    for (int i = logn; ~i; i--) {
        if (fa[u][i] != fa[v][i]) {
            A = merge(A, g[u][i][0]);
            B = merge(g[v][i][1], B);
            u = fa[u][i], v = fa[v][i];
        }
    }
    next_part:;
    node C; int x = fa[u][0];
    C.dp[0][0] = f[x][0] - dp1[u][1] - dp1[v][1];
    C.dp[1][1] = f[x][1] - min(dp1[u][0], dp1[u][1]) - min(dp1[v][0], dp1[v][1]);
    return merge(A, merge(C, B)).get_min();
}

int main() {
    scanf("%d %d %*s", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    for (int i = 1, u, v; i < n; i++) {
        scanf("%d %d", &u, &v);
        add_edge(u, v), add_edge(v, u);
    }
    dfs1(1, 0);
    dfs2(1, 0);
    /*
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j < 2; j++) {
            printf("%d %d %lld\n", i, j, f[i][j]);
        }
    }
    */
    dfs3(1, 0);
    for (int a, x, b, y; m--; ) {
        scanf("%d %d %d %d", &a, &x, &b, &y);
        llong ret = solve(a, x, b, y);
        printf("%lld\n", ret == infl ? -1 : ret);
    }
    return 0;
}

/*
5 3 C3
2 4 1 3 9
1 5
5 2
5 3
3 4
1 0 3 0
2 1 3 1
1 0 5 0

5 1 C3
1 1 1 1 1
1 2
1 3
2 4
3 5
4 1 5 1
*/


来源:Galaxy Coder