题面:
题意:
有两个长度分别为 n 和 m 的数组 a 和 b,以及一个线性基 S,我们定义两个数组 x和 y 匹配当且仅当这两个数组长度相等且 i=1∑len[xi⨁yi∈S]=len,即两个数组相对应的数异或之后的结果在线性基 S 中。求出a中所有与b匹配的区间。
官方题解:
考虑满足 ai⨁bi∈S 的条件,令 x 和 y 为 S 能够表示的两个数,上式可以写为 ∃x,y∈S,ai⨁bi=x⨁y,等价于ai⨁x=bi⨁y (两边同时 ⨁bi⨁x)。
我们考虑用线性基将 ai 和 bi 中能够消除的部分消掉,那么余下的部分就是无法被线性基表示的部分,显然只有当 ai 和 bi 无法被表示的部分相等时, ai⨁bi 才能被表示。
ai 消去 B 中的位得到的 ai′,那么 ai′ 中不包含线性基 B 中的位。
同理, bi′ 中 也不包含线性基 B 中的位。如果 ai′=bi′,那么 ai′⨁bi′=0,且 ai′⨁bi′ 不包含 B 中的位(即存在 B 无法表示的位)。那么 ai⨁bi∈/S。
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<unordered_set>
#include<set>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
//#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
#define fhead(x) for(int i=head[(x)];i;i=nt[i])
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const double alpha=0.75;
const int mod=1e9+7;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=200100;
const int maxm=100100;
const int maxp=100100;
const int up=29;
int a[maxn],b[maxn],s[maxn],p[32];
void _insert(int val)
{
for(int i=up;i>=0;i--)
{
if((val>>i)&1)
{
if(!p[i])
{
p[i]=val;
break;
}
else val^=p[i];
}
}
}
void creat(int n)
{
memset(p,0,sizeof(p));
for(int i=1;i<=n;i++)
_insert(s[i]);
}
int ask(int val)
{
for(int i=up;i>=0;i--)
{
if((val>>i)&1)
{
if(p[i])
val^=p[i];
}
}
return val;
}
ll mypow(ll a,ll b)
{
ll ans=1;
while(b)
{
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
int nt[maxn];
ll getkmp(int n,int m)
{
nt[1]=0;
for(int i=2,j=0;i<=m;i++)
{
while(j>0&&b[i]!=b[j+1]) j=nt[j];
if(b[i]==b[j+1]) j++;
nt[i]=j;
}
ll ans=0;
for(int i=1,j=0;i<=n;i++)
{
while(j>0&&(j==m||a[i]!=b[j+1])) j=nt[j];
if(a[i]==b[j+1]) j++;
if(j==m) ans=(ans+mypow(2,i-m))%mod;
}
return ans;
}
int main(void)
{
int tt;
scanf("%d",&tt);
while(tt--)
{
int n,m,k;
scanf("%d%d%d",&n,&m,&k);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(int i=1;i<=m;i++)
scanf("%d",&b[i]);
for(int i=1;i<=k;i++)
scanf("%d",&s[i]);
creat(k);
for(int i=1;i<=n;i++)
a[i]=ask(a[i]);
for(int i=1;i<=m;i++)
b[i]=ask(b[i]);
printf("%lld\n",getkmp(n,m));
}
return 0;
}