牛客做的人比较少,建议在洛谷上做。
3个操作,2个查询,线段树模拟大题。来想一想要维护什么信息。首先, 和
在这道题中反复转换,所以每种信息肯定给
和
分别维护一份,这样在反转时直接
交换即可,用一个只有两个元素的数组搞定。
然后考虑查询 要维护的信息。 可以知道我们要维护每一个区间
和
的数目
,这还是比较好操作的,合并的时候直接加起来就好了,如果全部要变成
或
,就把这个数目改为区间长度。区间长度就等于
和
的数目之和 。
接着思考查询 要维护的信息,我们要维护最大连续的
和
的长度
,并且要让合并的时候也可以维护这个长度。根据上一道题的经验,我们可以维护两个从端点向中间扩展的最大长度*,这样合并的时候最大连续长度就可以取以下三者的最大值:
①左区间 。
②右区间 。
③“中间情况”:从左区间的右端点向左扩展的最大连续长度 从右区间的左端点向右扩展最大连续长度。
最后涉及到区间修改,我们还要有一个状态懒标记 ,我这里定义了它的四种情况:
① 表示没有发生过修改的状态。
② 表示全部覆盖为0的状态。
③ 表示全部覆盖为1的状态。
④ 表示0和1反转状态。
好了,剩下就是漫长的模拟了,需要细心地打下每一个函数,为 减轻负担 []~( ̄▽ ̄)~*。看到这里可以自己尝试了,如果有不清楚的可以看我的分部分(介绍顺序不一定是在主函数中定义的顺序)讲解。
首先是节点结构体定义和建树部分:
struct node{ //状态,0或1连续的长度,0或1的数量,从两边扩展0或1的最大长度 int state,len[2],cnt[2],l[2],r[2]; }t[maxn<<2]; void build(int now,int l,int r){ if(l == r){ bool b; cin >> b; //读入原序列0或1,!b表示取相反,0->1,1->0,也可以用异或 t[now].len[b] = t[now].cnt[b] = t[now].l[b] = t[now].r[b] = 1; t[now].len[!b] = t[now].cnt[!b] = t[now].l[!b] = t[now].r[!b] =0; t[now].state = -1; return; } build(ls,l,mid); build(rs,mid+1,r); pushup(now); }
然后是有点烦人的 ,注意我推荐是用一个
循环来节省代码量 ,因为
和
是处理是完全一样的,并且检查方便。我之前是复制后把0改成1,
了好多次才发现是哪里忘记改了。
void pushup(int now){ t[now].state = -1; int lenL,lenR; //左区间的长度,右区间的长度 lenL = t[ls].cnt[0] + t[ls].cnt[1]; lenR = t[rs].cnt[0] + t[rs].cnt[1]; for(int i = 0;i <= 1;i++){ //0或1的数量 t[now].cnt[i] = t[ls].cnt[i] + t[rs].cnt[i]; //连续的长度在三者取最大 t[now].len[i] = max(t[ls].len[i] ,max(t[rs].len[i],t[ls].r[i]+t[rs].l[i])); //0或1的延展长度 t[now].l[i] = t[ls].l[i]; t[now].r[i] = t[rs].r[i]; //判断是否左或右区间全是0或1 if(t[ls].l[i] == lenL) t[now].l[i] += t[rs].l[i]; if(t[rs].r[i] == lenR) t[now].r[i] += t[ls].r[i]; } }
然后是修改和 操作,更长,需要考虑到懒标记的变换了。
void update(int now,int l,int r,int x,int y,int op){ if(x <= l && r <= y){ //如果命令是全部变成0,或者命令是反转并且原来全是1,执行的结果是一样的 if(op == 0 || (t[now].state == 1 && op == 2)){ t[now].state = 0; t[now].cnt[0] = t[now].l[0] = t[now].r[0] = t[now].len[0] = r-l+1; t[now].cnt[1] =t[now].l[1] = t[now].r[1] = t[now].len[1] = 0; }else if(op == 1|| (t[now].state == 0 && op == 2)){ t[now].state = 1; t[now].cnt[1] = t[now].l[1] = t[now].r[1] = t[now].len[1] = r-l+1; t[now].cnt[0] =t[now].l[0] = t[now].r[0] = t[now].len[0] = 0; }else{ if(t[now].state == 2) t[now].state = -1; else t[now].state = 2; swap(t[now].l[0],t[now].l[1]);swap(t[now].r[0],t[now].r[1]); swap(t[now].cnt[0],t[now].cnt[1]);swap(t[now].len[0],t[now].len[1]); } return; } if(t[now].state != -1) pushdown(now, r - l + 1); if(x <= mid) update(ls,l,mid,x,y,op); if(y > mid) update(rs,mid+1,r,x,y,op); pushup(now); } void pushdown(int now,int len){ //全是0或者1状态 if(t[now].state == 0 || t[now].state == 1) for(int i = 0;i <= 1;i++){ if(t[now].state == i){ t[ls].state = t[rs].state = i; t[ls].len[i] = t[ls].r[i] = t[ls].l[i] = t[ls].cnt[i] = len-len/2; t[rs].len[i] = t[rs].r[i] = t[rs].l[i] = t[rs].cnt[i] = len/2; t[ls].len[i^1] = t[ls].r[i^1] = t[ls].l[i^1] = t[ls].cnt[i^1] = 0; t[rs].len[i^1] = t[rs].r[i^1] = t[rs].l[i^1] = t[rs].cnt[i^1] = 0; } } //全部反转状态 else{ //如果原来就处于反转状态了,就变成正常状态的-1 if(t[ls].state == 2) t[ls].state = -1; //如果原来是0或者1,就变成相反的状态,0变成1,1,变成0 else if(t[ls].state != -1)t[ls].state ^= 1; //否则,就是从正常状态变成反转状态2 else t[ls].state = 2; //右孩子同理处理 if(t[rs].state == 2) t[rs].state = -1; else if(t[rs].state != -1)t[rs].state ^= 1; else t[rs].state = 2; //全部交换,8个swap,左孩子4个,右孩子4个 swap(t[ls].l[0],t[ls].l[1]);swap(t[ls].r[0],t[ls].r[1]); swap(t[ls].cnt[0],t[ls].cnt[1]);swap(t[ls].len[0],t[ls].len[1]); swap(t[rs].l[0],t[rs].l[1]);swap(t[rs].r[0],t[rs].r[1]); swap(t[rs].cnt[0],t[rs].cnt[1]);swap(t[rs].len[0],t[rs].len[1]); } t[now].state = -1; }
然后是两个查询操作:
//查询区间[x,y]中1的数目 int query_tot(int now,int l,int r,int x,int y){ if(x <= l && r <= y) return t[now].cnt[1]; if(t[now].state != -1) pushdown(now, r - l + 1); int ans = 0; if(x <= mid) ans += query_tot(ls,l,mid,x,y); if(y > mid) ans+= query_tot(rs,mid+1,r,x,y); return ans; } //查询区间[x,y]中连续的1最长长度,返回节点 node query_len(int now, int l, int r, int x, int y){ if(x <= l && r <= y) return t[now]; if(t[now].state != -1) pushdown(now, r - l + 1); int lenL,lenR; node fa,lef,rig; if(x <= mid) lef = query_len(ls,l,mid,x,y); if(y > mid) rig = query_len(rs,mid+1,r,x,y); //和pushup合并类似 if( x <= mid && y > mid){ lenL = lef.cnt[0] + lef.cnt[1]; lenR = rig.cnt[0] + rig.cnt[1]; for(int i = 0;i <= 1;i++){ //0或1的数量 fa.cnt[i] = lef.cnt[i] + rig.cnt[i]; //连续的长度在三者取最大 fa.len[i] = max(lef.len[i] , max(rig.len[i], lef.r[i] + rig.l[i])); //0或1的延展长度 fa.l[i] = lef.l[i], fa.r[i] = rig.r[i]; //判断是否左或右区间全是0或1 if(lef.l[i] == lenL) fa.l[i] += rig.l[i]; if(rig.r[i] == lenR) fa.r[i] += lef.r[i]; } return fa; }else if(x <= mid) return lef; else if(y > mid) return rig; }
主函数比较简单,我直接给出完整代码了。
:
#include<bits/stdc++.h> using namespace std; #define For(i,sta,en) for(int i = sta;i <= en;i++) #define ls now<<1 #define rs now<<1|1 #define mid (l+r)/2 #define speedUp_cin_cout ios::sync_with_stdio(false);cin.tie(0); cout.tie(0); const int maxn = 1e5+9; struct node{ //状态,0或1连续的长度,0或1的数量,从两边扩展0或1的最大长度 int state,len[2],cnt[2],l[2],r[2]; }t[maxn<<2]; int n,m; void pushup(int now){ t[now].state = -1; int lenL,lenR; //左区间的长度,右区间的长度 lenL = t[ls].cnt[0] + t[ls].cnt[1]; lenR = t[rs].cnt[0] + t[rs].cnt[1]; for(int i = 0;i <= 1;i++){ //0或1的数量 t[now].cnt[i] = t[ls].cnt[i] + t[rs].cnt[i]; //连续的长度在三者取最大 t[now].len[i] = max(t[ls].len[i] ,max(t[rs].len[i],t[ls].r[i]+t[rs].l[i])); //0或1的延展长度 t[now].l[i] = t[ls].l[i]; t[now].r[i] = t[rs].r[i]; //判断是否左或右区间全是0或1 if(t[ls].l[i] == lenL) t[now].l[i] += t[rs].l[i]; if(t[rs].r[i] == lenR) t[now].r[i] += t[ls].r[i]; } } void pushdown(int now,int len){ //全是0或者1状态 if(t[now].state == 0 || t[now].state == 1) for(int i = 0;i <= 1;i++){ //全为0或者1 if(t[now].state == i){ t[ls].state = t[rs].state = i; t[ls].len[i] = t[ls].r[i] = t[ls].l[i] = t[ls].cnt[i] = len-len/2; t[rs].len[i] =t[rs].r[i] = t[rs].l[i] = t[rs].cnt[i] = len/2; t[ls].len[i^1] =t[ls].r[i^1] = t[ls].l[i^1] = t[ls].cnt[i^1] = 0; t[rs].len[i^1] =t[rs].r[i^1] = t[rs].l[i^1] = t[rs].cnt[i^1] = 0; } } //全部反转状态 else{ //如果原来就处于反转状态了,就变成正常状态的-1 if(t[ls].state == 2) t[ls].state = -1; //如果原来是0或者1,就变成相反的状态,0变成1,1,变成0 else if(t[ls].state != -1)t[ls].state ^= 1; //否则,就是从正常状态变成反转状态2 else t[ls].state = 2; //右孩子同理处理 if(t[rs].state == 2) t[rs].state = -1; else if(t[rs].state != -1)t[rs].state ^= 1; else t[rs].state = 2; //全部交换,8个swap,左孩子4个,右孩子4个 swap(t[ls].l[0],t[ls].l[1]);swap(t[ls].r[0],t[ls].r[1]); swap(t[ls].cnt[0],t[ls].cnt[1]);swap(t[ls].len[0],t[ls].len[1]); swap(t[rs].l[0],t[rs].l[1]);swap(t[rs].r[0],t[rs].r[1]); swap(t[rs].cnt[0],t[rs].cnt[1]);swap(t[rs].len[0],t[rs].len[1]); } t[now].state = -1; } void build(int now,int l,int r){ if(l == r){ bool b; cin >> b; //读入原序列0或1,!b表示取相反,0->1,1->0,也可以用异或 t[now].len[b] = t[now].cnt[b] = t[now].l[b] = t[now].r[b] = 1; t[now].len[!b] = t[now].cnt[!b] = t[now].l[!b] = t[now].r[!b] =0; t[now].state = -1; return; } build(ls,l,mid); build(rs,mid+1,r); pushup(now); } void update(int now,int l,int r,int x,int y,int op){ if(x <= l && r <= y){ //如果命令是全部变成0,或者命令是反转并且原来全是1,执行的结果是一样的 if(op == 0 || (t[now].state == 1 && op == 2)){ t[now].state = 0; t[now].cnt[0] = t[now].l[0] = t[now].r[0] = t[now].len[0] = r-l+1; t[now].cnt[1] =t[now].l[1] = t[now].r[1] = t[now].len[1] = 0; }else if(op == 1|| (t[now].state == 0 && op == 2)){ t[now].state = 1; t[now].cnt[1] = t[now].l[1] = t[now].r[1] = t[now].len[1] = r-l+1; t[now].cnt[0] =t[now].l[0] = t[now].r[0] = t[now].len[0] = 0; }else{ if(t[now].state == 2) t[now].state = -1; else t[now].state = 2; swap(t[now].l[0],t[now].l[1]);swap(t[now].r[0],t[now].r[1]); swap(t[now].cnt[0],t[now].cnt[1]);swap(t[now].len[0],t[now].len[1]); } return; } if(t[now].state != -1) pushdown(now, r - l + 1); if(x <= mid) update(ls,l,mid,x,y,op); if(y > mid) update(rs,mid+1,r,x,y,op); pushup(now); } //查询区间[x,y]中1的数目 int query_tot(int now,int l,int r,int x,int y){ if(x <= l && r <= y) return t[now].cnt[1]; if(t[now].state != -1) pushdown(now, r - l + 1); int ans = 0; if(x <= mid) ans += query_tot(ls,l,mid,x,y); if(y > mid) ans+= query_tot(rs,mid+1,r,x,y); return ans; } //查询区间[x,y]中连续的1最长长度 node query_len(int now, int l, int r, int x, int y){ if(x <= l && r <= y) return t[now]; if(t[now].state != -1) pushdown(now, r - l + 1); int lenL,lenR; node fa,lef,rig; if(x <= mid) lef = query_len(ls,l,mid,x,y); if(y > mid) rig = query_len(rs,mid+1,r,x,y); //和pushup合并类似 if( x <= mid && y > mid){ lenL = lef.cnt[0] + lef.cnt[1]; lenR = rig.cnt[0] + rig.cnt[1]; for(int i = 0;i <= 1;i++){ //0或1的数量 fa.cnt[i] = lef.cnt[i] + rig.cnt[i]; //连续的长度在三者取最大 fa.len[i] = max(lef.len[i] , max(rig.len[i], lef.r[i] + rig.l[i])); //0或1的延展长度 fa.l[i] = lef.l[i], fa.r[i] = rig.r[i]; //判断是否左或右区间全是0或1 if(lef.l[i] == lenL) fa.l[i] += rig.l[i]; if(rig.r[i] == lenR) fa.r[i] += lef.r[i]; } return fa; }else if(x <= mid) return lef; else if(y > mid) return rig; } int main(){ speedUp_cin_cout cin>>n>>m;int op,l,r; build(1,1,n); For(i,1,m){ cin>>op>>l>>r; if(op <= 2) update(1,1,n,l+1,r+1,op); else if(op == 3) cout<<query_tot(1,1,n,l+1,r+1)<<endl; else cout<<query_len(1,1,n,l+1,r+1).len[1]<<endl; } return 0; }