题目
算法标签: 树形 d p dp dp, 树上倍增, L C A LCA LCA
思路
在原问题基础上还有每个点的限制, 要求某些点必须选择, 某些节点不能选择

将问题简化, 如果是一维问题, 可以前后缀分解来做


将问题回到树形问题, 首先考虑只有一个点选或者不选, 可以将整个树分为两部分, 也是类似于前后缀分解的方法满足一个点的限制


对于两个点以上情况, 使用倍增求解

对于两个点来说选与不选是两种情况, 然后对于 l c a lca lca也有两种情况, 选择或者不选择
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, M = N << 1, K = 17;
const LL INF = 1e18;
int n, m, p[N];
string type;
vector<int> head[N];
LL f[N][2], g[N][2], w[N][K][2][2];
int fa[N][K], depth[N];
void add(int u, int v) {
head[u].push_back(v);
}
void dfs_f(int u, int father) {
f[u][1] = p[u];
for (int v: head[u]) {
if (v == father) continue;
dfs_f(v, u);
f[u][0] += f[v][1];
f[u][1] += min(f[v][0], f[v][1]);
}
}
void dfs_g(int u, int father) {
for (int v: head[u]) {
if (v == father) continue;
g[v][0] = g[u][1] + f[u][1] - min(f[v][0], f[v][1]);
g[v][1] = min(g[v][0], g[u][0] + f[u][0] - f[v][1]);
dfs_g(v, u);
}
}
void dfs_fa(int u, int father) {
fa[u][0] = father;
for (int k = 1; k < K; ++k)
fa[u][k] = fa[fa[u][k - 1]][k - 1];
for (int v: head[u]) {
if (v == father) continue;
depth[v] = depth[u] + 1;
dfs_fa(v, u);
}
}
void dfs_w(int u, int father) {
for (int v: head[u]) {
if (v == father) continue;
w[v][0][0][0] = INF;
w[v][0][0][1] = f[u][1] - min(f[v][0], f[v][1]);
w[v][0][1][0] = f[u][0] - f[v][1];
w[v][0][1][1] = f[u][1] - min(f[v][0], f[v][1]);
for (int k = 1; k < K; ++k) {
int anc = fa[v][k - 1];
for (int x = 0; x < 2; ++x) {
for (int y = 0; y < 2; ++y) {
w[v][k][x][y] = INF;
for (int z = 0; z < 2; ++z) {
w[v][k][x][y] = min(w[v][k][x][y], w[v][k - 1][x][z] + w[anc][k - 1][z][y]);
}
}
}
}
dfs_w(v, u);
}
}
int lca(int a, int b) {
if (depth[a] < depth[b]) swap(a, b);
for (int k = K - 1; k >= 0; --k)
if (depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if (a == b) return a;
for (int k = K - 1; k >= 0; --k)
if (fa[a][k] != fa[b][k])
a = fa[a][k], b = fa[b][k];
return fa[a][0];
}
LL solve(int a, int x, int b, int y) {
if (depth[a] < depth[b]) swap(a, b), swap(x, y);
if (!x && !y && fa[a][0] == b) return -1;
LL sa[2] = {
INF, INF}, sb[2] = {
INF, INF};
sa[x] = f[a][x];
sb[y] = f[b][y];
for (int k = K - 1; k >= 0; --k) {
if (depth[fa[a][k]] >= depth[b]) {
LL na[2] = {
INF, INF};
for (int u = 0; u < 2; ++u) {
for (int v = 0; v < 2; ++v) {
na[v] = min(na[v], sa[u] + w[a][k][u][v]);
}
}
memcpy(sa, na, sizeof na);
a = fa[a][k];
}
}
if (a == b) return sa[y] + g[b][y];
for (int k = K - 1; k >= 0; --k) {
if (fa[a][k] != fa[b][k]) {
LL na[2] = {
INF, INF}, nb[2] = {
INF, INF};
for (int u = 0; u < 2; ++u) {
for (int v = 0; v < 2; ++v) {
na[v] = min(na[v], sa[u] + w[a][k][u][v]);
nb[v] = min(nb[v], sb[u] + w[b][k][u][v]);
}
}
memcpy(sa, na, sizeof na);
memcpy(sb, nb, sizeof nb);
a = fa[a][k];
b = fa[b][k];
}
}
int l = fa[a][0];
LL res0 = f[l][0] - f[a][1] - f[b][1] + sa[1] + sb[1] + g[l][0];
LL res1 = f[l][1] - min(f[a][0], f[a][1]) - min(f[b][0], f[b][1])
+ min(sa[0], sa[1]) + min(sb[0], sb[1]) + g[l][1];
return min(res0, res1);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n >> m >> type;
for (int i = 1; i <= n; ++i) cin >> p[i];
for (int i = 1; i < n; ++i) {
int u, v;
cin >> u >> v;
add(u, v);
add(v, u);
}
dfs_f(1, 0);
dfs_g(1, 0);
depth[1] = 1;
dfs_fa(1, 0);
// 初始化w数组
for (int i = 1; i <= n; ++i)
for (int k = 0; k < K; ++k)
for (int x = 0; x < 2; ++x)
for (int y = 0; y < 2; ++y)
w[i][k][x][y] = INF;
dfs_w(1, 0);
while (m--) {
int a, x, b, y;
cin >> a >> x >> b >> y;
cout << solve(a, x, b, y) << "\n";
}
return 0;
}


京公网安备 11010502036488号