解题思路
用到的知识点:st表+堆
思维过程
按照题目的要求,我们要求所有长度为[l,r]的子区间中最大的k个,首先,我们不可能遍历所有的子区间,因为那是O(n^2)的,我们考虑贪心
先考虑k==1时:
当k=1,我们只需要找最大的一个子区间,但是,这也要遍历所有子区间,于是考虑如何用较小的时间找出最大的子区间,由于题目的n为500000,所以一般来说,存在O(n)或者O(nlogn)遍历子区间的方法,而且很可能是O(nlogn)
首先,通过预处理前缀和数组sum,我们可以在O(1)的时间得到一段区间[i,i+k]的和,即sum[i+k]-sum[i-1]
然后,我们发现以i为起点的所有子区间中最大的那个是[i,i+l-1] U RMQ([i+l,i+r-1])(区间[i+l,i+r-1]中最大的前缀和),用st表解决RMQ问题,一次RMQ的复杂度为O(1),而i从1循环到n-l+1就能找出所有子区间中最大的,st表预处理的复杂度是O(nlogn),问题解决了!
再考虑k>1:
首先,最大的子区间一定是上一问得到的答案,设最大的子区间是以i为起点,t为终点(i+l<=t && t<=i+r-1),那么次大的子区间无非两种情况:
1.不以i为起点的子区间
2.以i为起点,不以t为终点的子区间
那么,我们只需要在找最大的子区间时,把第一问中那些不同起点的最大子区间都入堆,同时把[i,i+l-1] U RMQ([i+l,i+t-1]) 的区间 以及 [i,i+l-1] U RMQ([i+t+1,i+r-1])入堆,然后每次取堆顶,并将堆顶分裂的子区间入堆即可。
总的复杂度:O(nlogn)
注意事项:
1.如何找到那个t(区间断点)
这个简单,只有稍微修改一下st表的代码即可:
//st表部分,用于求[i+l,i+r-1]区间内前缀和的最大值 int f[mx][19];//f[i][j]表示从i开始到i+(2^k)-1区间的最大值 int pos[mx][19];//pos[i][j]表示最大值的位置 void init(){ for(int j=1;(1<<j)<=n;++j) for(int i=1;i+(1<<j)-1<=n;++i){ if(f[i][j-1]<f[i+(1<<(j-1))][j-1]){ f[i][j]=f[i+(1<<(j-1))][j-1]; pos[i][j]=pos[i+(1<<(j-1))][j-1]; } else{ f[i][j]=f[i][j-1]; pos[i][j]=pos[i][j-1]; } } } int query(int l,int r){//查找最大值的位置 int len=int(log(r-l+1)/log(2)); return f[l][len]>f[r-(1<<len)+1][len]?pos[l][len]:pos[r-(1<<len)+1][len]; }
2.分裂区间时请注意,不要超过区间的边界,如果到了边界,就不要入堆
完整代码
#include<iostream> #include<cmath> #include<cstdio> using namespace std; const int mx=505050; inline int Read(){ int x=0,f=1; char c=getchar(); while(c>'9'||c<'0')f=c!='-'?1:-1,c=getchar(); while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar(); return x*f; } int n,k,l,r; //st表部分,用于求[i+l,i+r-1]区间内前缀和的最大值 int f[mx][19];//f[i][j]表示从i开始到i+(2^k)-1区间的最大值 int pos[mx][19];//pos[i][j]表示最大值的位置 void init(){ for(int j=1;(1<<j)<=n;++j) for(int i=1;i+(1<<j)-1<=n;++i){ if(f[i][j-1]<f[i+(1<<(j-1))][j-1]){ f[i][j]=f[i+(1<<(j-1))][j-1]; pos[i][j]=pos[i+(1<<(j-1))][j-1]; } else{ f[i][j]=f[i][j-1]; pos[i][j]=pos[i][j-1]; } } } int query(int l,int r){//查找最大值的位置 int len=int(log(r-l+1)/log(2)); return f[l][len]>f[r-(1<<len)+1][len]?pos[l][len]:pos[r-(1<<len)+1][len]; } //大根堆部分 struct Node{ int s;//从s开始连续l个 int t;//区间断点 int l;//右端点最小值 int r;//右端点最大值 int val;//区间和 bool operator > (const Node &b){ return val>b.val; } void swap(Node &a,Node &b){ Node t=a; a=b; b=t; } }; struct Heap{ int size; Node h[mx*2];//删一个点加两个点,最多2*k void push(int s,int t,int l,int r,int val){ Node a; a.s=s,a.t=t,a.l=l,a.r=r,a.val=val; push(a); } void push(Node x){ h[++size]=x; int now=size; int fa=now>>1; while(fa>=1){ if(h[now]>h[fa]){ swap(h[now],h[fa]); now=fa; fa=now>>1; } else break; } } void pop(){ swap(h[1],h[size]); --size; int now=1; int son=now<<1; while(son<=size){ if((son|1)<=size&&h[son|1]>h[son])son|=1; if(h[son]>h[now]){ swap(h[son],h[now]); now=son; son=now<<1; } else break; } } Node top(){ return h[1]; } bool empty(){ return !size; } }heap; int main(){ n=Read(),k=Read(),l=Read(),r=Read(); for(int i=1;i<=n;++i)f[i][0]=f[i-1][0]+Read(),pos[i][0]=i;//f[i][0]为前缀和数组 init(); for(int i=1;i+l-1<=n;++i){ Node a; a.s=i; a.l=l+i-1; a.r=min(r+i-1,n); a.t=query(a.l,a.r); a.val=f[a.t][0]-f[a.s-1][0]; heap.push(a); } long long ans=0; for(int i=1;i<=k;i++){ Node a=heap.top(); heap.pop(); ans+=a.val; Node b=a;//左半区 b.r=a.t-1; if(b.l<=b.r){ b.t=query(b.l,b.r); b.val=f[b.t][0]-f[b.s-1][0]; heap.push(b); } Node c=a; c.l=a.t+1; if(c.l<=c.r){ c.t=query(c.l,c.r); c.val=f[c.t][0]-f[c.s-1][0]; heap.push(c); } } printf("%lld",ans); return 0; }
我的堆是手写的,如果懒,可以用优先队列