题意

给定一长度为n的序列a, 有两种操作;

1: 指定序列区间[l, r] 区间加上x;

2: 指定序列区间[l, r] 输出输出区间中小于x的元素的数量;

分析

·这是典型的“区间修改区间查询”类的题目

·所以我们可以很自然地思想到使用线段树

·通过线段树维护区间的最大值

区间加

·对区间加操作,我们使用懒标记, 只更新恰好完全包含在指定区间内的区间

·否则pushdown, 将懒标记下传。

void update(int l, int r, ll num, int plc)
{
	if (l > stree[plc].r || r < stree[plc].l) return ;
	if (l <= stree[plc].l && stree[plc].r <= r)
	{
		stree[plc].maxv += num;
		stree[plc].lazy += num; // 懒标记
		return ;
	}
	pushdown(plc);
	int mid = (stree[plc].l + stree[plc].r) >> 1;
	if (mid >= l) update(l, r, num, plc << 1);
	if (mid < r) update(l, r, num, plc << 1 | 1);
	pushup(plc);
}

区间查询

·对查询操作,若当前区间恰好包含在查询区间内,且其最大值小于x(说明当前区间的元素都小于x),则返回当前区间的长度;

·否则pushdown.

·注意区间完全包含时需要特判刚好是叶子结点的情况返回,否则会导致叶子结点pushdown而溢出RE.

int query(int l, int r, ll num, int plc)
{
	if (l > stree[plc].r || r < stree[plc].l) return 0;
	if (l <= stree[plc].l && stree[plc].r <= r)
	{
		if (stree[plc].maxv < num) return stree[plc].r - stree[plc].l + 1; // 返回区间长度
		if (stree[plc].l == stree[plc].r) return 0; // 特判叶子结点
	}
		
	pushdown(plc);
	int res = 0;
	int mid = (stree[plc].l + stree[plc].r) >> 1;
	if (mid >= l) res += query(l, r, num, plc << 1);
	if (mid < r) res += query(l, r, num, plc << 1 | 1);
	return res; 
}

但是

·因为该题下的查询操作,并不同于常见的线段树区间查询(当前区间完全包含于指定区间时即可返回)

·该题中的返回条件有if (stree[plc].maxv < num)

·即仅当当前区间最大值比x小时才返回,会导致区间查询的当前区间完全包含于指定区间时仍然会往下递归

·导致区间查询的复杂度由O(logn)退化,最坏能到O(n).

·所以只维护最大值的线段树会在30个测试点里有1个爆TLE.

·所以我们考虑减枝

减枝

·显然,若当前区间的最小值都已经大于等于x了,那区间中就不会有元素满足小于x

·所以我们可以多维护一个区间最小值用来减枝

int query(int l, int r, ll num, int plc)
{
	if (l > stree[plc].r || r < stree[plc].l) return 0;
	if (stree[plc].minv >= num) return 0; // 减枝
	if (l <= stree[plc].l && stree[plc].r <= r)
	{
		if (stree[plc].maxv < num) return stree[plc].r - stree[plc].l + 1;
		if (stree[plc].l == stree[plc].r) return 0;
	}
		
	pushdown(plc);
	int res = 0;
	int mid = (stree[plc].l + stree[plc].r) >> 1;
	if (mid >= l) res += query(l, r, num, plc << 1);
	if (mid < r) res += query(l, r, num, plc << 1 | 1);
	return res; 
}

·这样就能过了(

复杂度

·线段树初始化时间复杂度 O(n)

·线段树区间加时间复杂度 O(logn)

·线段树查询时间复杂度 最坏 O(n)、 平均O(logn)

·总体时间复杂度 最坏O(qn)、 平均O(qlogn)

·空间复杂度 O(4n)

完整代码

#include <iostream>
#include <cstdio>
#include <algorithm>

using namespace std;
typedef long long ll;
const int N = 100010;

int n, q;
ll a[N];
struct STree
{
	int l = 0, r = 0;
	ll maxv, minv, lazy;
}stree[N << 2];

void pushup(int plc)
{
	stree[plc].maxv = max(stree[plc << 1].maxv, stree[plc << 1 | 1].maxv);
	stree[plc].minv = min(stree[plc << 1].minv, stree[plc << 1 | 1].minv);
}

void pushdown(int plc)
{
	if (!stree[plc].lazy) return ;
	ll lz = stree[plc].lazy;
	int lson = plc << 1, rson = plc << 1 | 1;
	
	stree[lson].maxv += lz;
	stree[lson].minv += lz;
	stree[lson].lazy += lz;
	
	stree[rson].maxv += lz;
	stree[rson].minv += lz;
	stree[rson].lazy += lz;
	
	stree[plc].lazy = 0; 
}

void init(int l, int r, int plc)
{
	if (l > r) return ;
	stree[plc].l = l, stree[plc].r = r;
	stree[plc].lazy = 0;
	if (l == r)
	{
		stree[plc].maxv = a[l];
		stree[plc].minv = a[l];
		return ;
	}
	int mid = (l + r) >> 1;
	init(l, mid, plc << 1);
	init(mid + 1, r, plc << 1 | 1);
	pushup(plc);
}

void update(int l, int r, ll num, int plc)
{
	if (l > stree[plc].r || r < stree[plc].l) return ;
	if (l <= stree[plc].l && stree[plc].r <= r)
	{
		stree[plc].maxv += num;
		stree[plc].minv += num;
		stree[plc].lazy += num;
		return ;
	}
	pushdown(plc);
	int mid = (stree[plc].l + stree[plc].r) >> 1;
	if (mid >= l) update(l, r, num, plc << 1);
	if (mid < r) update(l, r, num, plc << 1 | 1);
	pushup(plc);
}

int query(int l, int r, ll num, int plc)
{
	if (l > stree[plc].r || r < stree[plc].l) return 0;
	if (stree[plc].minv >= num) return 0;
	if (l <= stree[plc].l && stree[plc].r <= r)
	{
		if (stree[plc].maxv < num) return stree[plc].r - stree[plc].l + 1;
		if (stree[plc].l == stree[plc].r) return 0;
	}
		
	pushdown(plc);
	int res = 0;
	int mid = (stree[plc].l + stree[plc].r) >> 1;
	if (mid >= l) res += query(l, r, num, plc << 1);
	if (mid < r) res += query(l, r, num, plc << 1 | 1);
	return res; 
}

int main()
{
	scanf("%d%d", &n, &q);
	for (int i = 1; i <= n; i ++ ) scanf("%lld", &a[i]);
	
	init(1, n, 1);
	for (int i = 0; i < q; i ++ )
	{
		int op, l, r;
		ll x;
		scanf("%d%d%d%lld", &op, &l, &r, &x);
		if (op == 1) update(l, r, x, 1);
		else printf("%d\n", query(l, r, x, 1));
	}
	
	return 0;
}

希望对大家有所帮组qwq.