E题
蒟蒻听直播听的不是特别懂,有很多疑惑,于是看了看dalao们的代码和解释,于是就有了这篇题解,侧重于本鶸不太懂得地方,可能对大家有帮助。
题目大意
给定一棵有n个节点的树,每个节点的权值未知,满足
,其中
为给定数值。每条边的权值为已知,其值为它连接的两个边的权值的异或值。求出满足条件的
的数量。(
)
思路
由于每条边的值为给定值,所以只要确定了一个点的值,剩下的点的值也就被确定了,遍历整个树需要的时间,所以枚举
的点,也会TLE。所以我们不能通过枚举可行值来计算答案。
我们首先假设
,那么剩下所有的点都会被确定为
那么假设我们把
的值修改为了
,那么剩下的点都为被改为
。
显然剩下的点也需要满足
。
假如不等式两边可以同时异或的话,那么上式就可以转化为
那么问题就转换为,对于
个上式的不等式,求出
的可行解。
但是对于一个不等式,是不能够直接异或的,因为对于一个区间,异或上一个数后,得到的新的区间,不一定是连续的。注意用词,是不一定连续,假如我们可以把区间分解为几个在异或一个数之后,仍然为连续的区间的话,那么问题就会转变为**对于多个形如上式的不等式,求出
的可行解**。所以我们来寻找什么样的区间,可以满足这个条件。
二进制下,只要满足000,001,010,011 .....111之间的每个数都存在,那么这个区间异或上一个数之后,还是一个连续的区间。(大家可以试一试)
进一步来说,假如某个区间的所有数,在二进制下为n位,前面k位数都一样,且剩下的位数,刚好满足从0到2^(n - k+1) -1的每个数都有,那么这个区间异或上一个数之后也还为连续的,因为异或之后的每个数,前面k位依然相同,后面的数异或后也是一个连续的区间,那么连起来依然是一个连续的区间。
3.那么我们就可以统计以下每个区间异或之后的新区间,统计有哪些地方被覆盖的次数超过次(即满足
个不等式)。我们可以用类似于权值线段树的形式,即每个点代表当前这个值被覆盖的次数。为什么要用线段树呢?因为如果直接对整个区间进行统计的话,就需要从1统计到2^31-1,显然是无法通过空间限制的,但是用线段树统计的时候,我们可以用打标记,即类似于区间修改的方式来提前终止递归,从而降低了空间复杂度与时间复杂度。(一个范围为W的区间最多被分为
个部分)
4.那么代码主要就分为以下几步:
- 先默认
求出每个点的初始权值
- 把每个区间分为异或后依然连续的区间(权值线段树操作)
- 计算异或后的区间并作统计(差分)
- 统计合法的结果数量(对差分数组排序,求前缀和)
代码
#include <cstdio> #include <iostream> #include <vector> #include <algorithm> using namespace std; const int maxN = 1e5 + 5; int n, l[maxN], r[maxN], head[maxN], cnt, W[maxN], ans, sum; struct SegmentTree { int val, laz; bool operator < (const SegmentTree &t)const { if(val != t.val) return val < t.val; return laz < t.laz; } }; struct Edge { int from, to, w; }e[maxN << 1]; inline void add(int u, int v, int w) { e[++cnt].from = head[u]; e[cnt].to = v; e[cnt].w = w; head[u] = cnt; } void dfs(int u,int fa,int val) { //第一步 W[u] =val; for(int i = head[u]; i; i = e[i].from) { int v =e [i].to,p =e [i].w; if(v != fa) dfs(v, u, val ^ p); } } vector<SegmentTree> V; void operation(int l, int r, int val) //求出每个区间异或后的区间并作差分 { SegmentTree a1, a2; int len = r - l + 1; a1.val = (l ^ (val & (~(len - 1)))); a2.val = a1.val + len; a1.laz = 1; a2.laz = -1; V.push_back(a1); V.push_back(a2); } void Search(int L,int R,int l,int r,int val) { //第二步:分区间 if(L <= l && R >= r) { operation(l, r, val); return ; } int mid = (l + r) >> 1; if(L <= mid) Search(L, R, l, mid, val); if(R > mid) Search(L, R, mid + 1, r, val); } int main() { scanf("%d", &n); for(int i = 1; i <= n; ++i) scanf("%d%d", &l[i], &r[i]); for(int i = 1; i < n; ++i) { int x, y, z; scanf("%d%d%d", &x, &y, &z); add(x, y, z); add(y, x, z); } dfs(1, 0, 0); for(int i = 1; i <= n; ++i) Search(l[i], r[i], 0, (1 << 30) - 1, W[i]); sort(V.begin(), V.end());//对差分数组进行差分 int len = V.size(), sum = 0, ans = 0; for(int i = 0; i < len; ++i) { //求前缀和并做统计 sum += V[i].laz; if(sum == n) ans += V[i + 1].val - V[i].val; } printf("%d\n", ans); return 0; }