G 魔法树
这题第一感觉就是树形DP
但是这里面的状态蛮多的,不太好下手。
其实可以分情况讨论
- 全是奇数
- 全是偶数
这样的话,dfs中需要维护的状态反而少了。
还有一点就是,这题的核心是边
也就是边的选择/不选决定了方案数
只有理解这点,才能理解下面的转移方程
- 全是奇数连通分量
每个节点,定义如下两个状态(s1, s2)
-
s1表示包含该节点的连通图 累加和为 奇数
-
s2表示包含该节点的连通图 累加和为 偶数
那状态转移为
s1' = s1 * t1 + s1 * t2 + s2 * t1;
s2' = s1 * t1 + s2 * t1 + s2 * t2;
- 全是偶数的连通分量
每个节点,定义如下两个状态(k1, k2)
-
k1表示包含该节点的连通图 累加和为 奇数
-
k2表示包含该节点的连通图 累加和为 偶数
那状态转移为
k1' = k1 * t2 * 2 + k2 * t1;
k2' = k1 * t1 + k2 * t2 * 2;
所以最终结果为
s1 + k2
import java.io.*;
import java.util.*;
public class Main {
static class Solution {
static long mod = 998244353l;
int n;
int[] ws;
List<Integer> []g;
long solve(int n, int[] ws, List<Integer> []g) {
this.n = n;
this.ws = ws;
this.g = g;
long[] r1 = dfs1(0, -1);
long[] r2 = dfs2(0, -1);
return (r1[0] + r2[1]) % mod;
}
// 树形DP
long[] dfs1(int u, int fa) {
long x1 = 0, x2 = 0;
if (ws[u] % 2 == 0) x2 = 1;
else x1 = 1;
for (int v: g[u]) {
if (v == fa) continue;
long[] cs = dfs1(v, u);
long t1 = cs[0], t2 = cs[1];
long x3 = x1 * t1 % mod + x1 * t2 % mod + x2 * t1 % mod;
long x4 = x1 * t1 % mod + x2 * t1 % mod + x2 * t2 % mod;
x1 = x3 % mod;
x2 = x4 % mod;
}
return new long[] {x1, x2};
}
long[] dfs2(int u, int fa) {
long x1 = 0, x2 = 0;
if (ws[u] % 2 == 0) x2 = 1;
else x1 = 1;
for (int v: g[u]) {
if (v == fa) continue;
long[] cs = dfs2(v, u);
long t1 = cs[0], t2 = cs[1];
long x3 = x1 * t2 % mod * 2 % mod + x2 * t1 % mod;
long x4 = x1 * t1 % mod + x2 * t2 % mod * 2 % mod;
x1 = x3 % mod;
x2 = x4 % mod;
}
return new long[] {x1, x2};
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(new BufferedInputStream(System.in));
int n = sc.nextInt();
int[] ws = new int[n];
for (int i = 0 ;i < n; i++) {
ws[i] = sc.nextInt() % 2;
}
List<Integer>[]g = new List[n];
Arrays.setAll(g, x -> new ArrayList<>());
for (int i = 0; i < n - 1; i++) {
int u = sc.nextInt() - 1, v = sc.nextInt() - 1;
g[u].add(v);
g[v].add(u);
}
Solution solution = new Solution();
long r = solution.solve(n, ws, g);
System.out.println(r);
}
}