题目链接
题目大意:
给一段数,进行三种操作
1、将一段区间所有数减去ai&-ai
2、将一段区间所有数加上2^k<= ai <2^(k+1)
3、区间和
总体思路:
求区间和可以直接用普通线段树来维护,但是操作1和2都不是线段树支持的区间操作。
对于操作1,我们可以采取直接暴力更新,但是我们会发现一个数减去一定次数的lowbit就会变为0,在0上进行更新操作是没有意义的,所以定义一个falg=1表示这段区间所有数都为0,那么这段区间就不用被更新了。
对于操作2,我们发现每次加上2^k,其实就是将最高位乘以2,所以我们可以将一个数分为高位和低位两部分来处理。每次更新操作2,其实将区间高位整体乘以2,这样就支持线段操作了。
后面就是线段树的基本操作了。
#include<bits/stdc++.h> using namespace std; typedef long long ll; const long long mod=998244353; int t; int n,m; int x; int ch,l,r; struct ty{ int l,r; ll la,lb;//la表示高位,lb表示低位 int flag;//判断区间是不是都为0 ll lazy;//最高位乘以的数 ty():l(0),r(0),la(0),lb(0),flag(0),lazy(1){} }; int a[100005],b[100005]; ty tr[100005*4]; void pushup(int u){ tr[u].la=(tr[u<<1].la+tr[u<<1|1].la)%mod; tr[u].lb=(tr[u<<1].lb+tr[u<<1|1].lb)%mod; tr[u].flag=tr[u<<1].flag&tr[u<<1|1].flag; return ; } void build(int u,int l,int r){ tr[u].l=l; tr[u].r=r; tr[u].lazy=1; tr[u].flag=0; if(l==r){ tr[u].la=a[l]%mod; tr[u].lb=b[l]%mod; return ; } int mind=(l+r)>>1; build(u<<1,l,mind); build(u<<1|1,mind+1,r); pushup(u); return ; } void pushdown(int u){ if(tr[u].lazy>1){ tr[u<<1].la=tr[u<<1].la*tr[u].lazy%mod; tr[u<<1|1].la=tr[u<<1|1].la*tr[u].lazy%mod; tr[u<<1].lazy=tr[u<<1].lazy*tr[u].lazy%mod; tr[u<<1|1].lazy=tr[u<<1|1].lazy*tr[u].lazy%mod; tr[u].lazy=1; } return ; } void modify1(int u,int l,int r){ if(tr[u].l==tr[u].r){ if(tr[u].lb){ tr[u].lb-=(tr[u].lb&-tr[u].lb); return ; } else{ tr[u].la=0; tr[u].flag=1; return ; } } pushdown(u); int mind=(tr[u].l+tr[u].r)>>1; if(l<=mind&&!tr[u<<1].flag) modify1(u<<1,l,r); if(r>mind&&!tr[u<<1|1].flag) modify1(u<<1|1,l,r); pushup(u); return ; } void modify2(int u,int l,int r){ if(tr[u].l>=l&&tr[u].r<=r){ tr[u].la=tr[u].la*2%mod; tr[u].lazy=tr[u].lazy*2%mod; return ; } pushdown(u); int mind=(tr[u].l+tr[u].r)>>1; if(l<=mind&&!tr[u<<1].flag) modify2(u<<1,l,r); if(r>mind&&!tr[u<<1|1].flag) modify2(u<<1|1,l,r); pushup(u); return ; } ll query(int u,int l,int r){ if(tr[u].l>=l&&tr[u].r<=r){ return (tr[u].la+tr[u].lb)%mod; } pushdown(u); int mind=(tr[u].l+tr[u].r)>>1; ll ans=0; if(l<=mind&&!tr[u<<1].flag) ans=(ans+query(u<<1,l,r))%mod; if(r>mind&&!tr[u<<1|1].flag) ans=(ans+query(u<<1|1,l,r))%mod; return ans; } int main(){ scanf(" %d",&t); while(t--){ scanf(" %d",&n); for(int i=1;i<=n;i++){ scanf(" %d",&x); for(int j=33;j>=0;j--){ if((1ll<<j)<=x){//**记得为1ll**,不然1<<j爆int会为0导致错误 a[i]=1ll<<j; b[i]=x-(1ll<<j); break; } } } build(1,1,n); scanf(" %d",&m); for(int i=1;i<=m;i++){ scanf(" %d %d %d",&ch,&l,&r); if(ch==1) printf("%lld\n",query(1,l,r)); else if(ch==2) modify1(1,l,r); else modify2(1,l,r); } } return 0; }