题号 NC200547
名称 划分树
来源 牛客练习赛57

给出一棵 n 个点的树,点编号 1..n , i 号点的点权是 aii 。

可以通过删边的方式将这棵树划分成一些连通块,求有多少种不同的划分方案,满足:划分后每个连通块的点权异或和均为 M 。

答案对 1004535809 取模。

样例

输入
3 0
1 1
0 0 0
输出
4

算法

(树形dp)

*难题同时也是好题

首先给出本道题动态规划的状态表示

然后我们需要给出以下的四个性质:(性质就解释了为什么这样定义状态表示)

  1. 只有1号节点的总权值异或和为0或者M时才有划分方案,否则答案为0
  2. 如果以非1号节点的为根节点的子树的权值和为0/M时v与其父节点之间的边是可以删除的
  3. 如果以非1号节点的为根节点的子树的权值和不为0/M时v与其父节点之间的边是不可以删除的
  4. 如果以非1号节点的为根节点的子树的权值和不为0/M时,删掉其子树中连向权值和为M的子树的边的方案数对最终答案也有可能是有贡献的

性质1:

​ 因为一个合法方案的结果得到的连通块的权值异或和都为M,那么我们根据删除边的方式将一个个连通块分成两个连通块同时权值异或和相异或

​ 如果删除偶数条边那么等于异或0,如果删除奇数条边那么等于异或M,包含1号节点的连通块的权值异或和不为M,所以只有1号节点的总权值异或和为0或者M时答案才有划分方案

性质2:

​ 如果删除的边连向的子树权值异或和为0,则删掉这条边后,再从被分离的子树中删除奇数条连向权值和为M的子树的边就是一个合法的划分

​ 如果删除的边连向的子树权值异或和为M,则删掉这条边后,再从被分离的子树中删除偶数条连向权值和为M的子树的边就是一个合法的划分

​ 所以可以将这个子树划分成一个单独的连通块,所以与其父节点之间的边是可以删除的

性质3:

​ 由于当前节点为根的子树的权值和即不为0又不为M,所以在这个子树中删除偶数条边后(等于异或了偶数个M)结果不是0或者M:

​ 在这个子树中删除奇数条边后(等于异或了偶数个M)结果依然不是0或者M,所以不能将这个子树划分成一个单独的块,

​ 所以与其父节点之间的边是不可以删除的

性质4:

​ 根据性质3,当前子树不能划分成一个单独的连通块,但是当他的某个祖先节点的权值异或和为0/M时

​ 那么删除这个子树中连向权值异或和为M的子树的边的方案是合法的,所以性质4是成立的


过程:
首先每个节点为根节点的子树权值异或和我们可以用dfs预处理出来

然后我们考虑状态转移:

初值:

  1. 如果节点u为根的子树的权值异或和为0/M,则令 (删除0条边,是删除偶数条边) ,令
  2. 如果节点u为根的子树的权值异或和不为0/M,则令 (删除0条边,是删除偶数条边) ,令

中间状态:

假设当前计算到u节点

我们先将u分为,u为根的子树权值异或和为0/M 或者 不为0/M两种情况

在每一个情况中依次考虑其子节点对方案的贡献(类似于树形背包问题的划分方式)

  1. 节点为根的子树权值异或和为0/M:

    1. 其子节点,以为根节点的子树权值异或和为0/M:

    2. 其子节点,以为根节点的子树权值异或和不为0/M:

      (根据性质4)

      (根据性质4)

    所有子节点考虑完毕后考虑连向父节点的边

    1. 节点为根的子树权值异或和为0 (u 不为1号节点):

      (根据性质2的讨论)

    2. 节点为根的子树权值异或和为M (u 不为1号节点):

      (根据性质2的讨论)

  2. 节点为根的子树权值异或和不为0/M:

    1. 其子节点,以为根节点的子树权值异或和为0/M:

    2. 其子节点,以为根节点的子树权值异或和不为0/M:


答案:

当整棵树的权值异或和为0 且 M == 0时输出:

当整棵树的权值异或和为0 且 M != 0时输出:

当整棵树的权值异或和为M 且 M != 0时输出:

否则输出:

细节:
1.当m == 0,且当前子树的权值异或和为0时,连向父节点的情况两种都成立,所以需要用一个变量分别记录f[u][0] ,f[u][1]
具体看代码注释
2.实际上如果u为根的子树权值和为0/M时g数组是为0的,u为根的子树权值和不为0/M时f数组是为0的,所以我们不用判断子节点的情况
具体看代码

我们发现f和g数组的区别只有是否考虑连向父节点的边的情况,所以我们可以将这两个数组合并,这就是官方题解的代码

时间复杂度

C++ 代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <unordered_map>
#include <map>
#include <vector>
#include <queue>
#include <set>
#include <bitset>
#include <cmath>

#define P 131

#define lc u << 1
#define rc u << 1 | 1

using namespace std;
typedef long long LL;
const int N = 500010,mod = 1004535809;
int h[N],ne[N * 2],e[N * 2],idx;
LL w[N];
LL sum[N],f[N][2],g[N][2];
int n;
LL m;

void add(int a,int b)
{
    e[idx] = b,ne[idx] = h[a],h[a] = idx ++;
}

void dfs1(int u,int father)
{
    sum[u] = w[u];
    for(int i = h[u];~i;i = ne[i])
    {
        int j = e[i];
        if(j == father) continue;
        dfs1(j,u);
        sum[u] ^= sum[j];
    }
}

void dfs2(int u,int father)
{
    if(sum[u] == 0 || sum[u] == m) f[u][0] = 1;
    else g[u][0] = 1;
    LL f0 = 0,f1 = 0,g0 = 0,g1 = 0;
    for(int i = h[u];~i;i = ne[i])
    {
        int j = e[i];
        if(j == father) continue;
        dfs2(j,u);
        f0 = f[u][0] * f[j][0] % mod + f[u][1] * f[j][1] % mod + f[u][0] * g[j][0] % mod + f[u][1] * g[j][1] % mod;
        f1 = f[u][0] * f[j][1] % mod + f[u][1] * f[j][0] % mod + f[u][0] * g[j][1] % mod + f[u][1] * g[j][0] % mod;

        g0 = g[u][0] * f[j][0] % mod + g[u][1] * f[j][1] % mod + g[u][0] * g[j][0] % mod + g[u][1] * g[j][1] % mod;
        g1 = g[u][1] * f[j][0] % mod + g[u][0] * f[j][1] % mod + g[u][0] * g[j][1] % mod + g[u][1] * g[j][0] % mod;

        f[u][0] = f0 % mod;
        f[u][1] = f1 % mod;
        g[u][0] = g0 % mod;
        g[u][1] = g1 % mod;
    }
    f0 = f[u][0],f1 = f[u][1];
    // 当m == 0 时下面两个判断都会成立所以需要用f0,f1分别记录
    if(sum[u] == 0 && u != 1) f[u][0] = (f[u][0] + f1) % mod;
    if(sum[u] == m && u != 1) f[u][1] = (f[u][1] + f0) % mod;
}

void solve()
{
    scanf("%d%lld",&n,&m);
    memset(h,-1,sizeof h);
    for(int i = 2;i <= n;i ++)
    {
        int x;
        scanf("%d",&x);
        add(x,i);
        add(i,x);
    }
    for(int i = 1;i <= n;i ++) scanf("%lld",&w[i]);
    dfs1(1,-1);
    dfs2(1,-1);
    if(m == 0 && sum[1] == 0) printf("%lld\n",(f[1][1] + f[1][0]) % mod);
    else if(sum[1] == 0) printf("%lld\n",f[1][1] % mod);
    else if(sum[1] == m) printf("%lld\n",f[1][0] % mod);
    else printf("0\n");
}

int main()
{
    #ifdef LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    #else
    #endif // LOCAL
    int T = 1;
    // init(500);
    // scanf("%d",&T);
    while(T --)
    {
        // scanf("%lld%lld",&n,&m);
        solve();
        // test();
    }
    return 0;
}