题意

给你一个长度为 \(n\) 的整数序列 \(a_1, a_2, \ldots, a_n\),你需要实现以下两种操作,每个操作都可以用四个整数 \(opt\ l\ r\ v\) 来表示:

  • \(opt=1\) 时,代表把一个区间 \([l, r]\) 内的所有数都 \(xor\)\(v\)

  • \(opt=2\) 时, 查询一个区间 \([l, r]\) 内选任意个数(包括 \(0\) 个)数 \(xor\) 起来,这个值与 \(v\) 的最大 \(xor\) 和是多少。

分析

线段树维护下线性基就行了,区间修改的时候记录下线段树每个结点的修改量\(k​\),合并的时候再加进线性基

因为线性基是构造出的一组极大线性无关组,所以查询\((a_i~xor~k)(i∈[l,r])\)组成的线性基等价于查询\(k∪a_i(i∈[l,r])​\)

Code

#include<bits/stdc++.h>
#define fi first
#define se second
#define bug cout<<"--------------"<<endl
using namespace std;
typedef long long ll;
const double PI=acos(-1.0);
const double eps=1e-6;
const int inf=1e9;
const ll llf=1e18;
const int mod=1e9+7;
const int maxn=5e4+10;
struct ji{
    int p[33],k;
    void clear(){
        memset(p,0,sizeof(p));
    }
    void insert(int x){
        for(int i=30;i>=0;i--){
            if(!((x>>i)&1)) continue;
            if(p[i]) x^=p[i];
            else{
                p[i]=x;
                break;
            }
        }
    }
    int qy(int x){
        int ret=x;
        for(int i=30;i>=0;i--) ret=max(ret^p[i],ret);
        return ret;
    }
};
int n,m;
int a[maxn],b[maxn],f[maxn],tag[maxn<<2];
ji tr[maxn<<2];
ji mer(ji a,ji b){
    ji ret=a;
    for(int i=30;i>=0;i--) if(b.p[i]) ret.insert(b.p[i]);
    ret.insert(ret.k^b.k);
    return ret;
}
void pushup(int p){
    tr[p]=mer(tr[p<<1],tr[p<<1|1]);
}
void tag1(int p,int x){
    tr[p].k^=x;
    tag[p]^=x;
}
void pushdown(int p){
    tag1(p<<1,tag[p]);
    tag1(p<<1|1,tag[p]);
    tag[p]=0;
}
void build(int l,int r,int p){
    if(l==r){
        scanf("%d",&tr[p].k);
        return;
    }
    int mid=l+r>>1;
    build(l,mid,p<<1);
    build(mid+1,r,p<<1|1);
    pushup(p);
}
void up(int dl,int dr,int l,int r,int p,int x){
    if(l>=dl&&r<=dr){
        tr[p].k^=x;
        tag[p]^=x;
        return;
    }
    pushdown(p);
    int mid=l+r>>1;
    if(dl<=mid) up(dl,dr,l,mid,p<<1,x);
    if(dr>mid) up(dl,dr,mid+1,r,p<<1|1,x);
    pushup(p);
}
ji ans;
void qy(int dl,int dr,int l,int r,int p){
    if(l>=dl&&r<=dr){
        ans=mer(ans,tr[p]);
        return; 
    }
    pushdown(p);
    int mid=l+r>>1;
    if(dl<=mid) qy(dl,dr,l,mid,p<<1);
    if(dr>mid) qy(dl,dr,mid+1,r,p<<1|1);
}
int main(){
    scanf("%d%d",&n,&m);
    build(1,n,1);
    while(m--){
        int op,l,r,v;
        scanf("%d%d%d%d",&op,&l,&r,&v);
        if(op==1){
            up(l,r,1,n,1,v);
        }else{
            ans.clear();
            qy(l,r,1,n,1);
            printf("%d\n",ans.qy(v));
        }
    }
    return 0;
}