牛客—— 红球进黑洞 (线段树+位运算)

铭宇巨巨推荐的题!

原题链接

题意:

给定一个序列,两种操作,一是区间求和,二是将区间里的每个数都异或x。

思路:

一眼就线段树,关键是怎么维护第二个操作。

借助最小异或生成树的思想以及异或题的常见套路,我们可以把每个数都进行二进制拆分,用线段树分别维护每一位上的0和1。

对于操作二,我们只需要分别维护每一位的值即可;对于操作一,计算区间里1的个数乘以对应的位数,就相当于是二进制转化为十进制的过程。

代码:

#pragma GCC optimize(3)
#pragma GCC optimize("Ofast","unroll-loops","omit-frame-pointer","inline")
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll,ll>PLL;
typedef pair<int,int>PII;
typedef pair<double,double>PDD;
#define I_int ll
#define x first
#define y second
inline ll read()
{
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
char F[200];
inline void out(I_int x) {
    if (x == 0) return (void) (putchar('0'));
    I_int tmp = x > 0 ? x : -x;
    if (x < 0) putchar('-');
    int cnt = 0;
    while (tmp > 0) {
        F[cnt++] = tmp % 10 + '0';
        tmp /= 10;
    }
    while (cnt > 0) putchar(F[--cnt]);
    //cout<<" ";
}
ll ksm(ll a,ll b,ll p){ll res=1;while(b){if(b&1)res=res*a%p;a=a*a%p;b>>=1;}return res;}
const int inf=0x3f3f3f3f,mod=1e9+7;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const int maxn=1e5+10,maxm=3e6+7;
const double PI = atan(1.0)*4;

struct node{
    int l,r,laz,sum;
}tr[25][maxn*4];
int a[maxn];
int n,m,k;

void pushup(int cnt,int u){
    tr[cnt][u].sum=tr[cnt][u<<1].sum+tr[cnt][u<<1|1].sum;
    tr[cnt][u].laz=0;
    return ;
}

void pushdown(int cnt,int u){
    if(tr[cnt][u].laz){
        tr[cnt][u].laz=0;
        tr[cnt][u<<1].laz^=1;tr[cnt][u<<1|1].laz^=1;
        tr[cnt][u<<1].sum=tr[cnt][u<<1].r-tr[cnt][u<<1].l+1-tr[cnt][u<<1].sum;
        tr[cnt][u<<1|1].sum=tr[cnt][u<<1|1].r-tr[cnt][u<<1|1].l+1-tr[cnt][u<<1|1].sum;
    }
    ///异或后1的个数 = 原来0的个数 = 区间长度(总个数)- 异或前1的个数
}

void build(int cnt,int u,int l,int r){
    tr[cnt][u]={l,r,0,0};
    if(l==r){
        tr[cnt][u].sum=(a[l]>>cnt)&1;
        return ;
    }
    int mid=(l+r)>>1;
    build(cnt,u<<1,l,mid);build(cnt,u<<1|1,mid+1,r);
    pushup(cnt,u);
}
///update(i,1,ql,qr,k&1);
void update(int cnt,int u,int ql,int qr,int x){
    if(tr[cnt][u].l>=ql&&tr[cnt][u].r<=qr){
        tr[cnt][u].sum=tr[cnt][u].r-tr[cnt][u].l+1-tr[cnt][u].sum;
        tr[cnt][u].laz^=x;
        return ;
    }
    pushdown(cnt,u);
    int mid=(tr[cnt][u].l+tr[cnt][u].r)>>1;
    if(ql<=mid) update(cnt,u<<1,ql,qr,x);
    if(qr>mid) update(cnt,u<<1|1,ql,qr,x);
    pushup(cnt,u);
}
///qask(i,1,ql,qr);
ll qask(int cnt,int u,int ql,int qr){
    if(tr[cnt][u].l>=ql&&tr[cnt][u].r<=qr) return tr[cnt][u].sum;
    pushdown(cnt,u);
    ll res=0;
    int mid=(tr[cnt][u].l+tr[cnt][u].r)>>1;
    if(ql<=mid) res+=qask(cnt,u<<1,ql,qr);
    if(qr>mid) res+=qask(cnt,u<<1|1,ql,qr);
    return res;
}

int main(){
    n=read(),m=read();
    for(int i=1;i<=n;i++) a[i]=read();
    for(int i=0;i<=20;i++) build(i,1,1,n);
    while(m--){
        int op=read(),ql=read(),qr=read();
        if(op==1){
            ll res=0;
            for(int i=0;i<=20;i++) res+=(1<<i)*qask(i,1,ql,qr);
            out(res);puts("");
        }
        else{
            int k=read();
            for(int i=0;i<=20;i++){
                if(k&1) update(i,1,ql,qr,k&1);
                k>>=1;
            }
        }
    }
    return 0;
}

参考博客 https://blog.csdn.net/luyehao1/article/details/84196619