G 魔法树

这题第一感觉就是树形DP

但是这里面的状态蛮多的,不太好下手。

其实可以分情况讨论

  • 全是奇数
  • 全是偶数

这样的话,dfs中需要维护的状态反而少了。

还有一点就是,这题的核心是边

也就是边的选择/不选决定了方案数

只有理解这点,才能理解下面的转移方程

  1. 全是奇数连通分量

每个节点,定义如下两个状态(s1, s2)

  • s1表示包含该节点的连通图 累加和为 奇数

  • s2表示包含该节点的连通图 累加和为 偶数

那状态转移为

s1' = s1 * t1 + s1 * t2 + s2 * t1;

s2' = s1 * t1 + s2 * t1 + s2 * t2;


  1. 全是偶数的连通分量

每个节点,定义如下两个状态(k1, k2)

  • k1表示包含该节点的连通图 累加和为 奇数

  • k2表示包含该节点的连通图 累加和为 偶数

那状态转移为

k1' = k1 * t2 * 2 + k2 * t1;

k2' = k1 * t1 + k2 * t2 * 2;

所以最终结果为

s1 + k2

import java.io.*;
import java.util.*;

public class Main {

    static class Solution {

        static long mod = 998244353l;

        int n;
        int[] ws;
        List<Integer> []g;

        long solve(int n, int[] ws, List<Integer> []g) {
            this.n = n;
            this.ws = ws;
            this.g = g;

            long[] r1 = dfs1(0, -1);

            long[] r2 = dfs2(0, -1);

            return (r1[0] + r2[1]) % mod;
        }

        // 树形DP
        long[] dfs1(int u, int fa) {
            long x1 = 0, x2 = 0;

            if (ws[u] % 2 == 0) x2 = 1;
            else x1 = 1;

            for (int v: g[u]) {
                if (v == fa) continue;

                long[] cs = dfs1(v, u);
                long t1 = cs[0], t2 = cs[1];
                
                long x3 = x1 * t1 % mod + x1 * t2 % mod + x2 * t1 % mod;
                long x4 = x1 * t1 % mod + x2 * t1 % mod + x2 * t2 % mod;

                x1 = x3 % mod;
                x2 = x4 % mod;
            }

            return new long[] {x1, x2};
        }

        long[] dfs2(int u, int fa) {
            long x1 = 0, x2 = 0;

            if (ws[u] % 2 == 0) x2 = 1;
            else x1 = 1;

            for (int v: g[u]) {
                if (v == fa) continue;

                long[] cs = dfs2(v, u);
                long t1 = cs[0], t2 = cs[1];

                long x3 = x1 * t2 % mod * 2 % mod + x2 * t1 % mod;
                long x4 = x1 * t1 % mod + x2 * t2 % mod * 2 % mod;

                x1 = x3 % mod;
                x2 = x4 % mod;
            }

            return new long[] {x1, x2};
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(new BufferedInputStream(System.in));

        int n = sc.nextInt();
        int[] ws = new int[n];
        for (int i = 0 ;i < n; i++) {
            ws[i] = sc.nextInt() % 2;
        }

        List<Integer>[]g = new List[n];
        Arrays.setAll(g, x -> new ArrayList<>());

        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt() - 1, v = sc.nextInt() - 1;
            g[u].add(v);
            g[v].add(u);
        }

        Solution solution = new Solution();
        long r = solution.solve(n, ws, g);
        System.out.println(r);
    }
    
}