链接:https://ac.nowcoder.com/acm/contest/26896/1021
来源:牛客网

题目描述

lxhgww最近收到了一个01序列,序列里面包含了n个数,这些数要么是0,要么是1,现在对于这个序列有五种变换操作和询问操作:
0 a b 把[a, b]区间内的所有数全变成0
1 a b 把[a, b]区间内的所有数全变成1
2 a b 把[a,b]区间内的所有数全部取反,也就是说把所有的0变成1,把所有的1变成0
3 a b 询问[a, b]区间内总共有多少个1
4 a b 询问[a, b]区间内最多有多少个连续的1 
对于每一种询问操作,lxhgww都需要给出回答,聪明的程序员们,你们能帮助他吗?

输入描述:

输入数据第一行包括2个数,n和m,分别表示序列的长度和操作数目
第二行包括n个数,表示序列的初始状态
接下来m行,每行3个数,op, a, b,(0 ≤ op ≤ 4,0 ≤ a ≤ b)

输出描述:

对于每一个询问操作,输出一行,包括1个数,表示其对应的答案

题型:

线段树--区间修改与区间查询--多个tag维护

思路:

(P.S.:这是自己目前写过的40+篇题解里面,思路最长的一篇了,应该比较详尽了,适合和自己一样刚学线段树的小白食用~
这一题难度中等,主要是由于需要维护的tag有点多,以及tag的处理部分有些繁琐
先看需要维护的tag有哪一些:
1.lazy标记:记录当前区间的状态,如-1表示原来的状态,0表示全变0,1表示全变1,2表示取反一次
2.sum0,sum1:记录区间内含有0/1的个数
3.max0,max1:记录区间内含有连续的0/1的最大值
4.left0,left1:记录区间内最左边连续的0/1的个数
5.right0,right1:记录区间内最右边连续的0/1的个数
分开来算的话,一共有9个
但是,由于我们发现,像sumi,maxi,lefti,righti这四个tag都是记录i的,所以可以把这8个tag变成4个,即sum[2],max[2],left[2],right[2],这样之后也方便操作一些

然后是思考每一个tag应该如何从子节点更新到父节点
1.sum[2],这个其实就是类似于求区间和,父节点的值由子节点直接相加就好
2.max[2],这个有三种情况,一是最大值等于其左子树的连续的最大值,二是等于其右子树的连续的最大值,三是等于其左子树的最右边连续的值+其右子树的最左边连续的值,这三者取最大者即可
3.left[2],这个有两种情况,一是其值等于其左子树的left,二是等于其左子树的left+其右子树的left(前提是左子树的left==左子树的长度)
4.right[2],这个有两种情况,一是其值等于其右子树的right,二是等于其右子树的right+其左子树的right(前提是右子树的right==右子树的长度)
由此,代码中的update函数已经写好了


然后思考每一个tag在区间修改与查询时如何更新(即pushdown函数如何写)
1.lazy=0或1,即区间修改为0/1
这个简单,只要把sum[lazy]=max[lazy]=left[lazy]=right[lazy]=区间长度;sum[lazy^1]=max[lazy^1]=left[lazy^1]=right[lazy^1]=0;就行了,注意左右子树都要执行一次
之后再把父节点的lazy变成-1,子节点的lazy变成对应的0/1就行
2.lazy=2,即区间翻转
这个想明白其实也不难
我们知道,这里的区间翻转只涉及了0/1,也就是说,设某个区间长度为m,且此区间内sum[0]=n,那么sum[1]=m-n,区间翻转后,sum[0]=m-n,sum[1]=n,其实就相当于swap(sum[0],sum[1]);max,left,right同理
所以,对于sum,left,right,max四个tag,只需要执行一次swap操作就行(即代码中的Swap函数)
然后是lazy的修改(重要),考虑分类讨论
1.lazy=0,则改为1;
2.lazy=1,则改为0;
3.lazy=2,则改为-1;
4.lazy=-1,则改为2;
注意左右子树都要改,且不要忘了改完之后将父节点的值改为-1
至此,pushdown函数也搞定了

建树build函数直接套板子即可,注意一下这些tag的初始化就行
区间修改(change)函数也直接套板子即可,注意一下这些tag的改变就行
区间翻转(reverse)函数直接套区间修改的板子即可,注意一下这些tag的改变就行(尤其注意一下lazy的变化,下面代码里面有)
最后是区间查询,op=3时,查询区间和也直接套板子即可
主要是op=4时,查询连续的1的最大值
对于op=4的时候,可以先把区间和的板子套上去,然后把求和那一部分的代码改成求最大值即可
唯一注意的一点就是,上面说到过,max有三种情况,一是最大值等于其左子树的连续的最大值,二是等于其右子树的连续的最大值,三是等于其左子树的最右边连续的值+其右子树的最左边连续的值,这三者取最大者即可

至此,所有问题完美解决

最后一点需要注意的是,牛客上的样例数据的格式有问题,所以写出来的代码(包括下面这个代码,在牛客内的自测运行的时候会报格式错误,不用管这个,直接提交就行,如果WA了也不是格式的问题!!

至于每一个函数的作用,下面的代码内附有注释

代码:

#include<bits/stdc++.h>
using namespace std;
const int N=100200;
struct node {
	int lazy,sum[2],max[2],left[2],right[2];
	//lazy:0--全变0,1--全变1,2--翻转,-1--原样
} tree[N*4];
int a[N];

void update(int p,int l,int r) {  //更新结构体中的sum,left,right,max的值
	int mid=(l+r)/2;
	for(int i=0; i<=1; i++) {
		tree[p].sum[i]=tree[p*2].sum[i]+tree[p*2+1].sum[i];
		if(tree[p*2].left[i]<(mid-l+1)) tree[p].left[i]=tree[p*2].left[i];
		else tree[p].left[i]=tree[p*2].left[i]+tree[p*2+1].left[i];
		if(tree[p*2+1].right[i]<(r-(mid+1)+1)) tree[p].right[i]=tree[p*2+1].right[i];
		else tree[p].right[i]=tree[p*2+1].right[i]+tree[p*2].right[i];
		int tmp=tree[p*2].right[i]+tree[p*2+1].left[i];
		tree[p].max[i]=max(tree[p*2].max[i],max(tree[p*2+1].max[i],tmp));
	}
}

void build(int p,int l,int r) { //建树
	tree[p].lazy=-1;
	if(l==r) {
		tree[p].sum[a[l]]=tree[p].max[a[l]]=tree[p].left[a[l]]=tree[p].right[a[l]]=1;
		return ;
	}
	int mid=(l+r)/2;
	build(p*2,l,mid);
	build(p*2+1,mid+1,r);
	update(p,l,r);
}

void Swap(int p) { //交换0/1值的操作
	swap(tree[p].sum[0],tree[p].sum[1]);
	swap(tree[p].max[0],tree[p].max[1]);
	swap(tree[p].left[0],tree[p].left[1]);
	swap(tree[p].right[0],tree[p].right[1]);
}

void pushdown(int p,int l,int r,int mid) { //更新子节点
	if(tree[p].lazy==2) {
		Swap(p*2);
		Swap(p*2+1);
		if(tree[p*2].lazy==-1) {
			tree[p*2].lazy=2;
		} else if(tree[p*2].lazy==2) {
			tree[p*2].lazy=-1;
		} else tree[p*2].lazy^=1;
		if(tree[p*2+1].lazy==-1) {
			tree[p*2+1].lazy=2;
		} else if(tree[p*2+1].lazy==2) {
			tree[p*2+1].lazy=-1;
		} else tree[p*2+1].lazy^=1;
		tree[p].lazy=-1;
	} else {
		int a=tree[p].lazy;
		tree[p*2].sum[a]=tree[p*2].max[a]=tree[p*2].left[a]=tree[p*2].right[a]=mid-l+1;
		tree[p*2].sum[a^1]=tree[p*2].max[a^1]=tree[p*2].left[a^1]=tree[p*2].right[a^1]=0;
		tree[p*2+1].sum[a]=tree[p*2+1].max[a]=tree[p*2+1].left[a]=tree[p*2+1].right[a]=r-(mid+1)+1;
		tree[p*2+1].sum[a^1]=tree[p*2+1].max[a^1]=tree[p*2+1].left[a^1]=tree[p*2+1].right[a^1]=0;
		tree[p*2].lazy=tree[p*2+1].lazy=a;
		tree[p].lazy=-1;
	}
}

void change(int p,int l,int r,int x,int y,int op) { //执行区间修改为0/1操作
	if(x<=l && y>=r) { //找到了对应节点
		tree[p].lazy=op;
		tree[p].sum[op^1]=tree[p].max[op^1]=tree[p].left[op^1]=tree[p].right[op^1]=0;
		tree[p].sum[op]=tree[p].max[op]=tree[p].left[op]=tree[p].right[op]=(r-l+1);
		return;
	}
	int mid=(l+r)/2;
	if(tree[p].lazy!=-1) {
		pushdown(p,l,r,mid);
	}
	if(x<=mid) change(p*2,l,mid,x,y,op); //左子树与x-y区间有交集
	if(y>mid) change(p*2+1,mid+1,r,x,y,op); //右子树与x-y区间有交集
	update(p,l,r);
}

void Reverse(int p,int l,int r,int x,int y) { //执行区间翻转操作
	if(x<=l && y>=r) { //找到了对应节点
		Swap(p);
		if(tree[p].lazy==-1) {
			tree[p].lazy=2;
		} else if(tree[p].lazy==2) {
			tree[p].lazy=-1;
		} else tree[p].lazy^=1;
		return;
	}
	int mid=(l+r)/2;
	if(tree[p].lazy!=-1) {
		pushdown(p,l,r,mid);
	}
	if(x<=mid) Reverse(p*2,l,mid,x,y); //左子树与x-y区间有交集
	if(y>mid) Reverse(p*2+1,mid+1,r,x,y); //右子树与x-y区间有交集
	update(p,l,r);
}

int find1(int p,int l,int r,int x,int y){ //执行区间和操作
	if(x<=l && y>=r){ //l-r区间包含在x-y区间内部(完全包含,直接加上这个区间的值即可)
		return tree[p].sum[1];
	}
	int mid=(l+r)/2;
	if(tree[p].lazy!=-1) pushdown(p,l,r,mid);
	int ans=0;
	if(x<=mid) ans+=find1(p*2,l,mid,x,y); 
	if(y>=mid+1) ans+=find1(p*2+1,mid+1,r,x,y);
	return ans;
}

node find2(int p,int l,int r,int x,int y){ //执行区间内有多少个连续的1的操作
	if(x<=l && y>=r){ //l-r区间包含在x-y区间内部(完全包含,直接加上这个区间的值即可)
		return tree[p];
	}
	int mid=(l+r)/2;
	if(tree[p].lazy!=-1) pushdown(p,l,r,mid);
	if(y<=mid) return find2(p*2,l,mid,x,y); 
	if(x>mid) return find2(p*2+1,mid+1,r,x,y);
	node tmp1=find2(p*2,l,mid,x,mid);
	node tmp2=find2(p*2+1,mid+1,r,mid+1,y);
	tmp1.max[1]=max(tmp1.max[1],max(tmp2.max[1],tmp1.right[1]+tmp2.left[1]));
	if(tmp1.left[1]==mid-l+1) tmp1.left[1]+=tmp2.left[1];
	if(tmp2.right[1]==r-(mid+1)+1) tmp1.right[1]+=tmp2.right[1];
	else tmp1.right[1]=tmp2.right[1];
	return tmp1;
}

int main() {
	int n,m;
	scanf("%d%d",&n,&m);
	for(int i=1; i<=n; i++) {
		scanf("%d",&a[i]);
	}
	build(1,1,n);
	int op,a,b;
	while(m--) {
		scanf("%d%d%d",&op,&a,&b);
		a++;
		b++;
		if(op==0) {
			change(1,1,n,a,b,op);
		}
		if(op==1) {
			change(1,1,n,a,b,op);
		}
		if(op==2) {
			Reverse(1,1,n,a,b);
		}
		if(op==3) {
			printf("%d\n",find1(1,1,n,a,b));
		}
		if(op==4) {
			printf("%d\n",find2(1,1,n,a,b).max[1]);
		}
	}
	return 0;
}