E题
蒟蒻听直播听的不是特别懂,有很多疑惑,于是看了看dalao们的代码和解释,于是就有了这篇题解,侧重于本鶸不太懂得地方,可能对大家有帮助。

题目大意

给定一棵有n个节点的树,每个节点的权值​​​未知,满足​​​,其中​​​为给定数值。每条边的权值为已知,其值为它连接的两个边的权值的异或值。求出满足条件的​​​​的数量。(​)

思路

由于每条边的值为给定值,所以只要确定了一个点的值,剩下的点的值也就被确定了,遍历整个树需要​的时间,所以枚举​的点,也会TLE。所以我们不能通过枚举可行值来计算答案。

  1. 我们首先假设​,那么剩下所有的点都会被确定为

    那么假设我们把的值修改为了,那么剩下的点都为被改为

    显然剩下的点也需要满足

    假如不等式两边可以同时异或的话,那么上式就可以转化为

    那么问题就转换为,对于​个上式的不等式,求出​的可行解。

  2. 但是对于一个不等式,是不能够直接异或的,因为对于一个区间,异或上一个数后,得到的新的区间,不一定是连续的。注意用词,是不一定连续,假如我们可以把区间分解为几个在异或一个数之后,仍然为连续的区间的话,那么问题就会转变为**对于多个形如上式的不等式,求出​​​的可行解**。所以我们来寻找什么样的区间,可以满足这个条件。

    二进制下,只要满足000,001,010,011 .....111之间的每个数都存在,那么这个区间异或上一个数之后,还是一个连续的区间。(大家可以试一试)

    进一步来说,假如某个区间的所有数,在二进制下为n位,前面k位数都一样,且剩下的位数,刚好满足从0到2^(n - k+1) -1的每个数都有,那么这个区间异或上一个数之后也还为连续的,因为异或之后的每个数,前面k位依然相同,后面的数异或后也是一个连续的区间,那么连起来依然是一个连续的区间。

3.那么我们就可以统计以下每个区间异或之后的新区间,统计有哪些地方被覆盖的次数超过次(即满足 ​个不等式)。我们可以用类似于权值线段树的形式,即每个点代表当前这个值被覆盖的次数。为什么要用线段树呢?因为如果直接对整个区间进行统计的话,就需要从1统计到2^31-1,显然是无法通过空间限制的,但是用线段树统计的时候,我们可以用打标记,即类似于区间修改的方式来提前终止递归,从而降低了空间复杂度与时间复杂度。(一个范围为W的区间最多被分为​个部分)

4.那么代码主要就分为以下几步:

  • 先默认求出每个点的初始权值
  • 把每个区间分为异或后依然连续的区间(权值线段树操作)
  • 计算异或后的区间并作统计(差分)
  • 统计合法的结果数量(对差分数组排序,求前缀和)

代码

#include <cstdio>
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

const int maxN = 1e5 + 5;
int n, l[maxN], r[maxN], head[maxN], cnt, W[maxN], ans, sum;

struct SegmentTree {
    int val, laz;
    bool operator < (const SegmentTree &t)const {
        if(val != t.val)
            return val < t.val;
        return laz < t.laz;
    }
};

struct Edge {
    int from, to, w;
}e[maxN << 1];

inline void add(int u, int v, int w)
{
    e[++cnt].from = head[u];
    e[cnt].to = v;
    e[cnt].w = w;
    head[u] = cnt;
}

void dfs(int u,int fa,int val) { //第一步
    W[u]  =val;
    for(int i = head[u]; i; i = e[i].from) {
        int v =e [i].to,p =e [i].w;
        if(v != fa)
            dfs(v, u, val ^ p);
    }
}

vector<SegmentTree> V;
void operation(int l, int r, int val) //求出每个区间异或后的区间并作差分
{
    SegmentTree a1, a2;
    int len = r - l + 1;
    a1.val = (l ^ (val & (~(len - 1))));
    a2.val = a1.val + len;
    a1.laz = 1; a2.laz = -1;
    V.push_back(a1); V.push_back(a2);
}

void Search(int L,int R,int l,int r,int val) { //第二步:分区间
    if(L <= l && R >= r) {
        operation(l, r, val);
        return ;
    }
    int mid = (l + r) >> 1;
    if(L <= mid)
        Search(L, R, l, mid, val);
    if(R > mid)
        Search(L, R, mid + 1, r, val);
}

int main()
{
    scanf("%d", &n);
    for(int i = 1; i <= n; ++i)
        scanf("%d%d", &l[i], &r[i]);
    for(int i = 1; i < n; ++i) {
        int x, y, z;
        scanf("%d%d%d", &x, &y, &z);
        add(x, y, z); add(y, x, z);
    }
    dfs(1, 0, 0);
    for(int i = 1; i <= n; ++i)
        Search(l[i], r[i], 0, (1 << 30) - 1, W[i]);
    sort(V.begin(), V.end());//对差分数组进行差分
    int len = V.size(), sum = 0, ans = 0;
    for(int i = 0; i < len; ++i) { //求前缀和并做统计
        sum += V[i].laz;
        if(sum == n)
            ans += V[i + 1].val - V[i].val;
    }
    printf("%d\n", ans);
    return 0;
}