问题:给定一个长度为n的序列(1<=n<=1e6),一个整数k(1<=k<=1e9),让你求出这个序列里有几个不同的区间满足区间内的数异或和大于等于k。
思路:转化前缀异或和。然后利用二进制和字典树解决。
一段区间的异或和x[l,r]=x[1,l-1]^x[1,r]。所以问题转化为找这个序列里 有几对序列前缀异或和异或起来的值是大于等于k的。记s[i]为前i个数的异或和。我们可以对于每一个i在[1,i-1]的范围里找有几个j,满足s[j]^s[i]>=k。
然后我们就可以用字典树来解决这个问题了。对于每一个i,先将s[i-1]按高位到低位存入字典树。每个结点存一个cnt,代表有几个前缀是这样的。假设s[i]^A = k,则A=k^s[i]。我们在字典树里按从高位到低位去跑A,如果k在跑到的当前这一位上为0,那么答案就加上与s[i]当前这一位相反的结点上的值(这些必定是和s[i]异或起来大于k的)。最后跑完A后加上最后一个结点的值(这些是和s[i]异或起来等于k的值)。直接看可能挺难理解的,在纸上模拟几遍可能就懂了。
AC代码:
//#include<bits/stdc++.h> #include<set> #include<map> #include<stack> #include<cmath> #include<ctime> #include<queue> #include<cstdio> #include<string> #include<vector> #include<cstdlib> #include<cstring> #include<iostream> #include<algorithm> #include<unordered_map> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<int, int> pii; typedef pair<ll, ll> pll; typedef pair<ll, int> pli; typedef pair<int, ll> pil; #define pb push_back #define X first #define Y second inline ll gcd(ll a, ll b) { while (b != 0) { ll c = a % b; a = b; b = c; }return a < 0 ? -a : a; } inline ll lcm(ll a, ll b) { return (a * b) / gcd(a, b); } inline ll lowbit(ll x) { return x & (-x); } const double PI = 3.14159265358979323846; const int inf = 0x3f3f3f3f; const ll INF = 0x3f3f3f3f3f3f3f3f; const ll mod = 998244353; inline char nc() { static char buf[1 << 21], * p1 = buf, * p2 = buf; return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++; } inline ll rd() { //#define getchar nc ll x = 0, f = 1; char ch = getchar(); while (ch<'0' || ch>'9') { if (ch == '-')f = -1; ch = getchar(); } while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); } return x * f; } const double eps = 1e-6; const int M = 1e6 + 10; const int N = 1e6 + 10; int trie[N * 32][2]; int a[N]; int tot = 1; int node[N * 32]; int k; void insert(int x) { int p = 1; for (int i = 30; i >= 0; i--) { int ch = (x >> i) & 1; if (!trie[p][ch])trie[p][ch] = ++tot; p = trie[p][ch]; node[p]++; } } ll search(int x) { int p = 1; ll sum = 0; for (int i = 30; i >= 0; i--) { int ch = (x >> i) & 1; if (((k >> i) & 1) == 0)sum += node[trie[p][ch ^ 1]]; p = trie[p][ch]; //if (p == 0)return sum; if (i == 0)sum += node[p]; } return sum; } int main() { int n = rd(); k = rd(); for (int i = 1; i <= n; i++)a[i] = a[i - 1] ^ rd(); ll ans = 0; for (int i = 1; i <= n; i++) { insert(a[i - 1]); ans += search(a[i] ^ k); } cout << ans << endl; }