题解:BISHI130 区间取反与区间数一
题目链接
题目描述
给定一个长度为 的二进制串
,支持两类操作:
- 区间取反:将
内所有位
;
- 区间数一:查询
内字符为
的个数。
解题思路
线段树 + 懒标记(翻转标记)。
- 结点维护当前区间内的
的数量
。
- 懒标记
表示该区间尚未下推的“取反”操作。对一个结点取反时:
,并将
取反。
- 区间取反:命中整段直接翻转;否则下推后递归两子区间并回收。
- 区间数一:标准区间和查询。
复杂度:每次操作 ,空间
。
代码
#include <bits/stdc++.h>
using namespace std;
struct SegTree {
int n; string s;
vector<int> cnt; // 区间内 1 的个数
vector<unsigned char> f; // 懒翻转标记(0/1)
SegTree(int n=0): n(n), cnt(4*n+4,0), f(4*n+4,0) {}
void build(int idx,int l,int r){
if(l==r){ cnt[idx] = (s[l-1]=='1'); return; }
int m=(l+r)>>1; build(idx<<1,l,m); build(idx<<1|1,m+1,r);
cnt[idx]=cnt[idx<<1]+cnt[idx<<1|1];
}
inline void apply_flip(int idx,int l,int r){ cnt[idx] = (r-l+1)-cnt[idx]; f[idx]^=1; }
inline void push(int idx,int l,int r){ if(!f[idx]) return; int m=(l+r)>>1; apply_flip(idx<<1,l,m); apply_flip(idx<<1|1,m+1,r); f[idx]=0; }
void range_flip(int idx,int l,int r,int ql,int qr){
if(ql<=l && r<=qr){ apply_flip(idx,l,r); return; }
push(idx,l,r); int m=(l+r)>>1;
if(ql<=m) range_flip(idx<<1,l,m,ql,qr);
if(qr>m) range_flip(idx<<1|1,m+1,r,ql,qr);
cnt[idx]=cnt[idx<<1]+cnt[idx<<1|1];
}
int range_sum(int idx,int l,int r,int ql,int qr){
if(ql<=l && r<=qr) return cnt[idx];
push(idx,l,r); int m=(l+r)>>1; int ans=0;
if(ql<=m) ans+=range_sum(idx<<1,l,m,ql,qr);
if(qr>m) ans+=range_sum(idx<<1|1,m+1,r,ql,qr);
return ans;
}
};
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n,q; if(!(cin>>n>>q)) return 0;
string s; cin>>s;
SegTree st(n); st.s=s; st.build(1,1,n);
while(q--){
int op,l,r; cin>>op>>l>>r;
if(op==1){ st.range_flip(1,1,n,l,r); }
else { cout<<st.range_sum(1,1,n,l,r)<<'\n'; }
}
return 0;
}
import java.io.*;
public class Main {
static class FastScanner {
private final InputStream in; private final byte[] buf = new byte[1<<16];
private int p=0,l=0; FastScanner(InputStream is){in=is;}
private int read() throws IOException { if(p>=l){ l=in.read(buf); p=0; if(l<=0) return -1; } return buf[p++]; }
int nextInt() throws IOException { int c; int s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
String next() throws IOException { int c; StringBuilder sb=new StringBuilder(); do{c=read();}while(c<=32); while(c>32){ sb.append((char)c); c=read(); } return sb.toString(); }
}
static class SegTree {
int n; char[] s; int[] cnt; boolean[] flip;
SegTree(int n){ this.n=n; cnt=new int[4*n+4]; flip=new boolean[4*n+4]; }
void build(int idx,int l,int r){ if(l==r){ cnt[idx]=(s[l-1]=='1')?1:0; return; } int m=(l+r)>>1; build(idx<<1,l,m); build(idx<<1|1,m+1,r); cnt[idx]=cnt[idx<<1]+cnt[idx<<1|1]; }
void applyFlip(int idx,int l,int r){ cnt[idx]=(r-l+1)-cnt[idx]; flip[idx]=!flip[idx]; }
void push(int idx,int l,int r){ if(!flip[idx]) return; int m=(l+r)>>1; applyFlip(idx<<1,l,m); applyFlip(idx<<1|1,m+1,r); flip[idx]=false; }
void rangeFlip(int idx,int l,int r,int ql,int qr){ if(ql<=l && r<=qr){ applyFlip(idx,l,r); return; } push(idx,l,r); int m=(l+r)>>1; if(ql<=m) rangeFlip(idx<<1,l,m,ql,qr); if(qr>m) rangeFlip(idx<<1|1,m+1,r,ql,qr); cnt[idx]=cnt[idx<<1]+cnt[idx<<1|1]; }
int rangeSum(int idx,int l,int r,int ql,int qr){ if(ql<=l && r<=qr) return cnt[idx]; push(idx,l,r); int m=(l+r)>>1; int ans=0; if(ql<=m) ans+=rangeSum(idx<<1,l,m,ql,qr); if(qr>m) ans+=rangeSum(idx<<1|1,m+1,r,ql,qr); return ans; }
}
public static void main(String[] args) throws Exception {
FastScanner fs = new FastScanner(System.in);
int n = fs.nextInt();
int q = fs.nextInt();
String s = fs.next();
SegTree st = new SegTree(n); st.s=s.toCharArray(); st.build(1,1,n);
StringBuilder out = new StringBuilder();
while(q-- > 0){
int op = fs.nextInt();
int l = fs.nextInt();
int r = fs.nextInt();
if(op==1) st.rangeFlip(1,1,n,l,r); else out.append(st.rangeSum(1,1,n,l,r)).append('\n');
}
System.out.print(out.toString());
}
}
import sys
data = sys.stdin.buffer.read().split()
it = iter(data)
n = int(next(it)); q = int(next(it))
s = next(it).decode().strip()
size = 4*n + 5
cnt = [0]*size
flip = [0]*size # 0/1 标记
def build(idx, l, r):
if l == r:
cnt[idx] = 1 if s[l-1] == '1' else 0
return
m = (l + r) >> 1
build(idx<<1, l, m)
build(idx<<1|1, m+1, r)
cnt[idx] = cnt[idx<<1] + cnt[idx<<1|1]
def apply_flip(idx, l, r):
cnt[idx] = (r - l + 1) - cnt[idx]
flip[idx] ^= 1
def push(idx, l, r):
if flip[idx] == 0:
return
m = (l + r) >> 1
apply_flip(idx<<1, l, m)
apply_flip(idx<<1|1, m+1, r)
flip[idx] = 0
def range_flip(idx, l, r, ql, qr):
if ql <= l and r <= qr:
apply_flip(idx, l, r)
return
push(idx, l, r)
m = (l + r) >> 1
if ql <= m:
range_flip(idx<<1, l, m, ql, qr)
if qr > m:
range_flip(idx<<1|1, m+1, r, ql, qr)
cnt[idx] = cnt[idx<<1] + cnt[idx<<1|1]
def range_sum(idx, l, r, ql, qr):
if ql <= l and r <= qr:
return cnt[idx]
push(idx, l, r)
m = (l + r) >> 1
ans = 0
if ql <= m:
ans += range_sum(idx<<1, l, m, ql, qr)
if qr > m:
ans += range_sum(idx<<1|1, m+1, r, ql, qr)
return ans
build(1, 1, n)
out = []
for _ in range(q):
op = int(next(it)); l = int(next(it)); r = int(next(it))
if op == 1:
range_flip(1, 1, n, l, r)
else:
out.append(str(range_sum(1, 1, n, l, r)))
sys.stdout.write('\n'.join(out))
算法及复杂度
- 算法:线段树 + 懒翻转标记
- 时间复杂度:
- 空间复杂度: