题解:BISHI130 区间取反与区间数一

题目链接

区间取反与区间数一

题目描述

给定一个长度为 的二进制串 ,支持两类操作:

  • 区间取反:将 内所有位
  • 区间数一:查询 内字符为 的个数。

解题思路

线段树 + 懒标记(翻转标记)。

  • 结点维护当前区间内的 的数量
  • 懒标记 表示该区间尚未下推的“取反”操作。对一个结点取反时:,并将 取反。
  • 区间取反:命中整段直接翻转;否则下推后递归两子区间并回收。
  • 区间数一:标准区间和查询。

复杂度:每次操作 ,空间

代码

#include <bits/stdc++.h>
using namespace std;

struct SegTree {
    int n; string s;
    vector<int> cnt;        // 区间内 1 的个数
    vector<unsigned char> f; // 懒翻转标记(0/1)
    SegTree(int n=0): n(n), cnt(4*n+4,0), f(4*n+4,0) {}
    void build(int idx,int l,int r){
        if(l==r){ cnt[idx] = (s[l-1]=='1'); return; }
        int m=(l+r)>>1; build(idx<<1,l,m); build(idx<<1|1,m+1,r);
        cnt[idx]=cnt[idx<<1]+cnt[idx<<1|1];
    }
    inline void apply_flip(int idx,int l,int r){ cnt[idx] = (r-l+1)-cnt[idx]; f[idx]^=1; }
    inline void push(int idx,int l,int r){ if(!f[idx]) return; int m=(l+r)>>1; apply_flip(idx<<1,l,m); apply_flip(idx<<1|1,m+1,r); f[idx]=0; }
    void range_flip(int idx,int l,int r,int ql,int qr){
        if(ql<=l && r<=qr){ apply_flip(idx,l,r); return; }
        push(idx,l,r); int m=(l+r)>>1;
        if(ql<=m) range_flip(idx<<1,l,m,ql,qr);
        if(qr>m)  range_flip(idx<<1|1,m+1,r,ql,qr);
        cnt[idx]=cnt[idx<<1]+cnt[idx<<1|1];
    }
    int range_sum(int idx,int l,int r,int ql,int qr){
        if(ql<=l && r<=qr) return cnt[idx];
        push(idx,l,r); int m=(l+r)>>1; int ans=0;
        if(ql<=m) ans+=range_sum(idx<<1,l,m,ql,qr);
        if(qr>m)  ans+=range_sum(idx<<1|1,m+1,r,ql,qr);
        return ans;
    }
};

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n,q; if(!(cin>>n>>q)) return 0;
    string s; cin>>s;
    SegTree st(n); st.s=s; st.build(1,1,n);
    while(q--){
        int op,l,r; cin>>op>>l>>r;
        if(op==1){ st.range_flip(1,1,n,l,r); }
        else { cout<<st.range_sum(1,1,n,l,r)<<'\n'; }
    }
    return 0;
}
import java.io.*;

public class Main {
    static class FastScanner {
        private final InputStream in; private final byte[] buf = new byte[1<<16];
        private int p=0,l=0; FastScanner(InputStream is){in=is;}
        private int read() throws IOException { if(p>=l){ l=in.read(buf); p=0; if(l<=0) return -1; } return buf[p++]; }
        int nextInt() throws IOException { int c; int s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
        String next() throws IOException { int c; StringBuilder sb=new StringBuilder(); do{c=read();}while(c<=32); while(c>32){ sb.append((char)c); c=read(); } return sb.toString(); }
    }

    static class SegTree {
        int n; char[] s; int[] cnt; boolean[] flip;
        SegTree(int n){ this.n=n; cnt=new int[4*n+4]; flip=new boolean[4*n+4]; }
        void build(int idx,int l,int r){ if(l==r){ cnt[idx]=(s[l-1]=='1')?1:0; return; } int m=(l+r)>>1; build(idx<<1,l,m); build(idx<<1|1,m+1,r); cnt[idx]=cnt[idx<<1]+cnt[idx<<1|1]; }
        void applyFlip(int idx,int l,int r){ cnt[idx]=(r-l+1)-cnt[idx]; flip[idx]=!flip[idx]; }
        void push(int idx,int l,int r){ if(!flip[idx]) return; int m=(l+r)>>1; applyFlip(idx<<1,l,m); applyFlip(idx<<1|1,m+1,r); flip[idx]=false; }
        void rangeFlip(int idx,int l,int r,int ql,int qr){ if(ql<=l && r<=qr){ applyFlip(idx,l,r); return; } push(idx,l,r); int m=(l+r)>>1; if(ql<=m) rangeFlip(idx<<1,l,m,ql,qr); if(qr>m) rangeFlip(idx<<1|1,m+1,r,ql,qr); cnt[idx]=cnt[idx<<1]+cnt[idx<<1|1]; }
        int rangeSum(int idx,int l,int r,int ql,int qr){ if(ql<=l && r<=qr) return cnt[idx]; push(idx,l,r); int m=(l+r)>>1; int ans=0; if(ql<=m) ans+=rangeSum(idx<<1,l,m,ql,qr); if(qr>m) ans+=rangeSum(idx<<1|1,m+1,r,ql,qr); return ans; }
    }

    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int n = fs.nextInt();
        int q = fs.nextInt();
        String s = fs.next();
        SegTree st = new SegTree(n); st.s=s.toCharArray(); st.build(1,1,n);
        StringBuilder out = new StringBuilder();
        while(q-- > 0){
            int op = fs.nextInt();
            int l = fs.nextInt();
            int r = fs.nextInt();
            if(op==1) st.rangeFlip(1,1,n,l,r); else out.append(st.rangeSum(1,1,n,l,r)).append('\n');
        }
        System.out.print(out.toString());
    }
}
import sys

data = sys.stdin.buffer.read().split()
it = iter(data)
n = int(next(it)); q = int(next(it))
s = next(it).decode().strip()

size = 4*n + 5
cnt = [0]*size
flip = [0]*size  # 0/1 标记

def build(idx, l, r):
    if l == r:
        cnt[idx] = 1 if s[l-1] == '1' else 0
        return
    m = (l + r) >> 1
    build(idx<<1, l, m)
    build(idx<<1|1, m+1, r)
    cnt[idx] = cnt[idx<<1] + cnt[idx<<1|1]

def apply_flip(idx, l, r):
    cnt[idx] = (r - l + 1) - cnt[idx]
    flip[idx] ^= 1

def push(idx, l, r):
    if flip[idx] == 0:
        return
    m = (l + r) >> 1
    apply_flip(idx<<1, l, m)
    apply_flip(idx<<1|1, m+1, r)
    flip[idx] = 0

def range_flip(idx, l, r, ql, qr):
    if ql <= l and r <= qr:
        apply_flip(idx, l, r)
        return
    push(idx, l, r)
    m = (l + r) >> 1
    if ql <= m:
        range_flip(idx<<1, l, m, ql, qr)
    if qr > m:
        range_flip(idx<<1|1, m+1, r, ql, qr)
    cnt[idx] = cnt[idx<<1] + cnt[idx<<1|1]

def range_sum(idx, l, r, ql, qr):
    if ql <= l and r <= qr:
        return cnt[idx]
    push(idx, l, r)
    m = (l + r) >> 1
    ans = 0
    if ql <= m:
        ans += range_sum(idx<<1, l, m, ql, qr)
    if qr > m:
        ans += range_sum(idx<<1|1, m+1, r, ql, qr)
    return ans

build(1, 1, n)

out = []
for _ in range(q):
    op = int(next(it)); l = int(next(it)); r = int(next(it))
    if op == 1:
        range_flip(1, 1, n, l, r)
    else:
        out.append(str(range_sum(1, 1, n, l, r)))

sys.stdout.write('\n'.join(out))

算法及复杂度

  • 算法:线段树 + 懒翻转标记
  • 时间复杂度:
  • 空间复杂度: