题解:BISHI122 区间后缀极大位置计数

题目链接

区间后缀极大位置计数

题目描述

给定长度为 的数列 ,固定子区间长度 。对每个长度为 的子区间 ),定义“后缀极大”位置个数:下标 满足对所有 的个数。按 的顺序输出每个子区间的后缀极大位置个数。

解题思路

等价判定: 在窗口 中为“后缀极大”当且仅当在 内不存在严格大于 的元素。设 的严格下一个更大元素位置(若不存在则为 ),则在以 结尾的窗口中, 被计入当且仅当

做法:

  • 先用单调栈在线预处理每个位置的 (严格大于,遇到 才弹)。
  • 随着 扫描,维护“当前活跃集合”
  • 用树状数组(BIT)在下标轴上维护活跃标记:
    • 到达 时将位置 置为活跃(加 ),
    • 同时把所有 的位置 从活跃中移除(加 )。
  • 对于固定 ,窗口 的答案为活跃个数在区间 的和,即 (当 时输出)。

每个位置仅被加入/移除一次,总复杂度 ,空间

代码

#include <bits/stdc++.h>
using namespace std;

struct BIT {
    int n; vector<int> t;
    BIT(int n=0): n(n), t(n+1,0) {}
    void add(int i, int v){ for(; i<=n; i+=i&-i) t[i]+=v; }
    int sum(int i){ int s=0; for(; i>0; i-=i&-i) s+=t[i]; return s; }
};

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, k; if(!(cin>>n>>k)) return 0;
    vector<long long> a(n+1);
    for(int i=1;i<=n;i++) cin>>a[i];

    // next strictly greater element to the right
    vector<int> nge(n+1, n+1);
    vector<int> st; st.reserve(n);
    for(int i=1;i<=n;i++){
        while(!st.empty() && a[st.back()] < a[i]){ nge[st.back()] = i; st.pop_back(); }
        st.push_back(i);
    }

    // bucket positions by their nge value
    vector<vector<int>> byR(n+2);
    for(int i=1;i<=n;i++) if(nge[i]<=n) byR[nge[i]].push_back(i);

    BIT bit(n);
    vector<int> ans;
    ans.reserve(max(0, n-k+1));
    for(int R=1; R<=n; R++){
        bit.add(R, 1);                 // position R becomes active
        for(int i: byR[R]) bit.add(i, -1); // positions with nge=i==R become inactive
        if(R>=k){
            int L = R - k + 1;
            int res = bit.sum(R) - bit.sum(L-1);
            ans.push_back(res);
        }
    }

    for(int x: ans) cout << x << '\n';
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    static class FastScanner {
        private final InputStream in; private final byte[] buf = new byte[1<<16];
        private int p=0,l=0; FastScanner(InputStream is){in=is;}
        private int read() throws IOException { if(p>=l){ l=in.read(buf); p=0; if(l<=0) return -1; } return buf[p++]; }
        int nextInt() throws IOException { int c; int s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
        long nextLong() throws IOException { int c; long s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x=x*10+(c-'0'); c=read(); } return x*s; }
    }

    static class BIT {
        int n; int[] t;
        BIT(int n){ this.n=n; t=new int[n+1]; }
        void add(int i,int v){ for(; i<=n; i+=i&-i) t[i]+=v; }
        int sum(int i){ int s=0; for(; i>0; i-=i&-i) s+=t[i]; return s; }
    }

    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int n = fs.nextInt();
        int k = fs.nextInt();
        long[] a = new long[n+1];
        for(int i=1;i<=n;i++) a[i]=fs.nextLong();

        int[] nge = new int[n+1];
        Arrays.fill(nge, n+1);
        int[] st = new int[n]; int top=0;
        for(int i=1;i<=n;i++){
            while(top>0 && a[st[top-1]] < a[i]){ nge[st[--top]] = i; }
            st[top++] = i;
        }

        ArrayList<Integer>[] byR = new ArrayList[n+2];
        for(int i=0;i<byR.length;i++) byR[i]=new ArrayList<>();
        for(int i=1;i<=n;i++) if(nge[i]<=n) byR[nge[i]].add(i);

        BIT bit = new BIT(n);
        StringBuilder out = new StringBuilder();
        for(int R=1; R<=n; R++){
            bit.add(R, 1);
            for(int i: byR[R]) bit.add(i, -1);
            if(R>=k){
                int L = R - k + 1;
                int res = bit.sum(R) - bit.sum(L-1);
                out.append(res).append('\n');
            }
        }
        System.out.print(out.toString());
    }
}
import sys

data = sys.stdin.buffer.read().split()
it = iter(data)
n = int(next(it)); k = int(next(it))
a = [0]*(n+1)
for i in range(1, n+1):
    a[i] = int(next(it))

# next strictly greater to the right
nge = [n+1]*(n+1)
st = []
for i in range(1, n+1):
    while st and a[st[-1]] < a[i]:
        nge[st.pop()] = i
    st.append(i)

byR = [[] for _ in range(n+2)]
for i in range(1, n+1):
    if nge[i] <= n:
        byR[nge[i]].append(i)

bit = [0]*(n+1)
def add(i, v):
    while i <= n:
        bit[i] += v
        i += i & -i

def sum_(i):
    s = 0
    while i > 0:
        s += bit[i]
        i -= i & -i
    return s

out = []
for R in range(1, n+1):
    add(R, 1)
    for idx in byR[R]:
        add(idx, -1)
    if R >= k:
        L = R - k + 1
        res = sum_(R) - sum_(L-1)
        out.append(str(res))

sys.stdout.write('\n'.join(out))

算法及复杂度

  • 算法:单调栈预处理 + 树状数组维护活跃下标
  • 时间复杂度:
  • 空间复杂度: