Continuous Intervals
线段树好题呀!比赛的时候根本看不出来,赛后惊叹“学到了!”
题意:给定一个数组,求数组内有多少连续区间。“连续区间”的定义:将区间内数字按大小排序后,相邻元素差值不大于1,可以等于0。
思路:绝妙的思路!线段树+区间修改+区间最小值及最小值个数+单调栈
- 可以将“连续区间”的定义等价为:区间最大值 −区间最小值 +1 −区间不同数字的个数 =0;定义区间最大值为 max,区间最小值为 min,区间不同数字的个数为 cnt,则此表达式可写为 max−min+1−cnt=0
- 再由于线段树能维护区间最小值及最小值个数,则将上述表达式改写为 max−min−cnt>=−1,因此当我们从小到大枚举区间右端点 R时,动态的维护所有的 L所对应的 [L,R]区间的 max−min−cnt的最小值及其个数,而线段树根节点的最小值一定为 −1,因为 [R,R]区间的此值为 −1,所以根节点最小值的个数就是当前 R所对应的所有 L中满足为"连续区间"的区间个数,累加至答案即可
- 问题关键现在就转化为如何求当前 R所对应的所有 [L,R]区间的 max,min,cnt
- 其中, max,min都能用单调栈维护,前者用单调递减的栈,后者用单调递增的栈,维护栈内每个元素所能影响到的区间,进行区间更新,而当一个元素被pop出去时,再进行一次区间更新消除其影响
- 同时, cnt的求解可以用一个 map存每个数字最后出现的位置(如果没有出现过,就认为出现在 0处,当然这些细节不重要,相信能处理好),然后从最后出现位置的后一位到当前位置进行区间更新,也就是区间的所有值都 −1
- 到目前为止,对于每个 R所对应的所有 [L,R]区间的 max−min−cnt的最小值及其个数都维护好了,只需要依次累加枚举的所有 R经过更新以后的根节点的值就好了
- 要是还没懂,看代码一定能看懂,非常清晰!
题目描述
#include "bits/stdc++.h"
#define hhh printf("hhh\n")
#define see(x) (cerr<<(#x)<<'='<<(x)<<endl)
using namespace std;
typedef long long ll;
typedef pair<int,int> pr;
inline int read() {int x=0;char c=getchar();while(c<'0'||c>'9')c=getchar();while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();return x;}
const int maxn = 1e5+10;
const int mod = 1e9+7;
const double eps = 1e-9;
int n;
ll node[maxn<<2], mi[maxn<<2], lazy[maxn<<2], a[maxn], ans;
stack<int> s1, s2;
map<ll,int> pos;
void push_up(int now) {
if(!now) return;
mi[now]=min(mi[now<<1],mi[now<<1|1]);
if(mi[now<<1]==mi[now<<1|1]) node[now]=node[now<<1]+node[now<<1|1];
else if(mi[now<<1]<mi[now<<1|1]) node[now]=node[now<<1];
else node[now]=node[now<<1|1];
push_up(now>>1);
}
void push_down(int now) {
ll &d=lazy[now];
lazy[now<<1]+=d, lazy[now<<1|1]+=d;
mi[now<<1]+=d, mi[now<<1|1]+=d;
d=0;
}
void build(int l, int r, int now) {
if(l==r) { node[now]=1, mi[now]=0, lazy[now]=0; return; } //此处mi[now]也可以不初始化为0,反正只需要最小值个数,具体是多少不重要
int m=(l+r)/2; lazy[now]=0;
build(l,m,now<<1), build(m+1,r,now<<1|1);
push_up(now);
}
void update(int x, int y, ll d, int l, int r, int now) {
if(x<=l&&r<=y) { lazy[now]+=d, mi[now]+=d; return; }
if(lazy[now]) push_down(now);
int m=(l+r)/2;
if(x<=m) update(x,y,d,l,m,now<<1);
if(y>m) update(x,y,d,m+1,r,now<<1|1);
push_up(now);
}
void solve(int cas) {
printf("Case #%d: ", cas);
while(s1.size()) s1.pop();
while(s2.size()) s2.pop();
ans=0; pos.clear();
n=read(); build(1,n,1);
for(int i=1; i<=n; ++i) scanf("%lld", &a[i]);
for(int i=1; i<=n; ++i) {
int x, y;
while(s1.size()&&a[s1.top()]<a[i]) {
y=s1.top(); s1.pop();
x=s1.size()?s1.top()+1:1;
update(x,y,-a[y],1,n,1);
}
y=i;
x=s1.size()?s1.top()+1:1;
update(x,y,a[y],1,n,1);
s1.push(i);
//以上通过维护单调递增的栈维护max
while(s2.size()&&a[s2.top()]>a[i]) {
y=s2.top(); s2.pop();
x=s2.size()?s2.top()+1:1;
update(x,y,a[y],1,n,1);
}
y=i;
x=s2.size()?s2.top()+1:1;
update(x,y,-a[y],1,n,1);
s2.push(i);
//以上通过维护单调递减的栈维护min
if(pos[a[i]]) update(pos[a[i]]+1,i,-1,1,n,1);
else update(1,i,-1,1,n,1);
pos[a[i]]=i;
//以上通过map维护区间不同数字的个数
ans+=node[1]; //累加答案
}
printf("%lld\n", ans);
}
int main() {
//ios::sync_with_stdio(false); cin.tie(0);
int T=read(), t=T;
while(T--) solve(t-T);
}