没有思路直接上代码
#include <iostream>
#include <cstdio>
#define lson u << 1
#define rson u << 1 | 1
using namespace std;
const int N = 1e5 + 10;
int n, m;
int a[N];
struct Node
{
int l, r;
int lazy;//0表示全变0 1表示全变1 2表示全变取反 -1表示什么都不干
int sum[2], ma[2], pre[2], suf[2];
//sum[i]总共用多少个i ma[i]最多多少个连续的i
//pre[i]前缀中最多多少个连续的i suf[i]后缀中最多多少个连续的i
}tr[N << 2];
template <class T>
inline void read(T & res)
{
char ch; bool flag = false;
while ((ch = getchar()) < '0' || ch > '9')
if (ch == '-') flag = true;
res = ch ^ 48;
while ((ch = getchar()) >= '0' && ch <= '9')
res = (res << 3) + (res << 1) + (ch ^ 48);
if (flag) res = ~res + 1;
}
void pushup(Node &u, Node &x, Node &y)
{
//考虑到下面会使用mid - u.l + 1以此来求区间长度
//但是这里的mid千万不能写成 int mid = u.l + u.r >> 1;
//因为这里的u节点里面不一定有东西,query的else里面pushup的是一个空的u
//所以直接使用x.r - x.l + 1来求区间长度
for (int i = 0; i < 2; i ++)
{
u.sum[i] = x.sum[i] + y.sum[i];
u.ma[i] = max(max(x.ma[i], y.ma[i]), x.suf[i] + y.pre[i]);
//这里直接压行
u.pre[i] = x.pre[i] + y.pre[i] * (x.pre[i] == x.r - x.l + 1);
u.suf[i] = y.suf[i] + x.suf[i] * (y.suf[i] == y.r - y.l + 1);
/*不压行的话
u.pre[i] = x.pre[i];
if (x.pre[i] == x.r - x.l + 1) u.pre[i] = x.pre[i] + y.pre[i];
u.suf[i] = y.suf[i];
if (y.suf[i] == y.r - y.l + 1) u.suf[i] = y.suf[i] + x.suf[i];
*/
}
}
void pushup(int u) {pushup(tr[u], tr[lson], tr[rson]);}
void build(int u, int l, int r)
{
tr[u].l = l, tr[u].r = r, tr[u].lazy = -1;//一开始懒标记要标记成-1,表示啥也不干
if (l == r)
{
tr[u].sum[a[l]] = tr[u].ma[a[l]] = tr[u].pre[a[l]] = tr[u].suf[a[l]] = 1;
return ;
}
int mid = l + r >> 1;
build(lson, l, mid), build(rson, mid + 1, r);
pushup(u);
}
//将u节点所代表的区间内所有数全变v
void change1(int u, int v)
{
tr[u].sum[v] = tr[u].ma[v] = tr[u].pre[v] = tr[u].suf[v] = tr[u].r - tr[u].l + 1;
tr[u].sum[v ^ 1] = tr[u].ma[v ^ 1] = tr[u].pre[v ^ 1] = tr[u].suf[v ^ 1] = 0;
tr[u].lazy = v;
}
//将u节点所代表的区间内所有数全部取反
void change2(int u)
{
swap(tr[u].sum[0], tr[u].sum[1]);
swap(tr[u].ma[0], tr[u].ma[1]);
swap(tr[u].pre[0], tr[u].pre[1]);
swap(tr[u].suf[0], tr[u].suf[1]);
if (tr[u].lazy == -1) tr[u].lazy = 2;//什么都没标记的话,标记取反
else if (tr[u].lazy == 2) tr[u].lazy = -1;//标记取反的话,两个取反就抵消了
else tr[u].lazy ^= 1;//否则0变1,1变0
}
void pushdown(int u)
{
if (tr[u].lazy == -1) return ;//啥也不干
else if (tr[u].lazy == 2)
{
change2(lson), change2(rson);
tr[u].lazy = -1;//清空懒标记
}
else//全变1或者全变0
{
change1(lson, tr[u].lazy), change1(rson, tr[u].lazy);
tr[u].lazy = -1;//清空懒标记
}
}
//将[x,y]区间内所有数变为v
void modify(int u, int x, int y, int v)
{
if (x <= tr[u].l && tr[u].r <= y)
{
change1(u, v);
return ;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(lson, x, y, v);
if (y > mid) modify(rson, x, y, v);
pushup(u);
}
//将[x,y]区间内所有数全部取反
void rev(int u, int x, int y)
{
if (x <= tr[u].l && tr[u].r <= y)
{
change2(u);
return ;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) rev(lson, x, y);
if (y > mid) rev(rson, x, y);
pushup(u);
}
Node query(int u, int x, int y)
{
if (x <= tr[u].l && tr[u].r <= y) return tr[u];
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (y <= mid) return query(lson, x, y);
else if (x > mid) return query(rson, x, y);
else
{
Node l, r, res;
l = query(lson, x, y), r = query(rson, x, y);
pushup(res, l, r);//这里的res是空的,所以pushup里面mid不能再像往常一样定义
return res;
}
}
int main()
{
read(n), read(m);
for (int i = 1; i <= n; i ++) read(a[i]);
build(1, 1, n);
for (int i = 1; i <= m; i ++)
{
int op, a, b;
read(op), read(a), read(b);
a ++, b ++;//题目下标从0开始,所以全部+1
if (op == 0) modify(1, a, b, 0);
else if (op == 1) modify(1, a, b, 1);
else if (op == 2) rev(1, a, b);
else if (op == 3) printf("%d\n", query(1, a, b).sum[1]);
else printf("%d\n", query(1, a, b).ma[1]);
}
return 0;
}