题目链接

PEEK61 【模板】静态区间最值

题目描述

给定一个长度为 的静态数组 。有 次查询,每次查询给定一个区间 ,要求回答该区间的最大值或最小值。

解题思路

本题是静态区间最值查询(Range Minimum/Maximum Query, RMQ)的模板题。由于数组内容不会改变,且查询次数较多,我们需要一个比每次都遍历区间更高效的数据结构。

解决此问题的经典高效算法是 稀疏表(Sparse Table, ST)。该算法的核心思想是利用动态规划和倍增的思想,预处理出所有长度为 的区间的最大值和最小值,从而实现 的查询。

  1. 预处理 ()

    • 我们创建两个二维数组,st_min[p][i]st_max[p][i]
    • st_min[p][i] 存储从下标 开始,长度为 的区间 [i, i + 2^p - 1] 内的最小值。st_max 同理。
    • 基础状态 (): 区间长度为 ,所以 st_min[0][i] = st_max[0][i] = a[i]
    • 递推关系: 对于 ,一个长度为 的区间可以看作是两个长度为 的重叠子区间的并集。因此:
      • st_min[p][i] = min(st_min[p-1][i], st_min[p-1][i + (1 << (p-1))])
      • st_max[p][i] = max(st_max[p-1][i], st_max[p-1][i + (1 << (p-1))])
    • 为了在查询时快速确定 的值,我们还需要预处理一个对数表 log_table,其中 log_table[i] 存储
  2. 查询 ()

    • 对于任意查询区间 [l, r],我们首先计算其长度 len = r - l + 1
    • 然后,利用对数表找到最大的整数 使得 ,即
    • 查询区间 [l, r] 可以被两个长度为 的区间 [l, l + 2^p - 1][r - 2^p + 1, r] 完全覆盖。
    • 由于 minmax 运算是幂等的(即 op(x, x) = x),两个区间重叠的部分不会影响最终结果。因此:
      • query_min(l, r) = min(st_min[p][l], st_min[p][r - (1 << p) + 1])
      • query_max(l, r) = max(st_max[p][l], st_max[p][r - (1 << p) + 1])

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>

using namespace std;

const int MAXN = 500005;
const int LOGN = 20;

int a[MAXN];
int st_min[LOGN][MAXN];
int st_max[LOGN][MAXN];
int log_table[MAXN];

void build_st(int n) {
    log_table[1] = 0;
    for (int i = 2; i <= n; i++) {
        log_table[i] = log_table[i / 2] + 1;
    }

    for (int i = 0; i < n; i++) {
        st_min[0][i] = a[i];
        st_max[0][i] = a[i];
    }

    for (int p = 1; p < LOGN; p++) {
        for (int i = 0; i + (1 << p) <= n; i++) {
            st_min[p][i] = min(st_min[p - 1][i], st_min[p - 1][i + (1 << (p - 1))]);
            st_max[p][i] = max(st_max[p - 1][i], st_max[p - 1][i + (1 << (p - 1))]);
        }
    }
}

int query_min(int l, int r) {
    int p = log_table[r - l + 1];
    return min(st_min[p][l], st_min[p][r - (1 << p) + 1]);
}

int query_max(int l, int r) {
    int p = log_table[r - l + 1];
    return max(st_max[p][l], st_max[p][r - (1 << p) + 1]);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }

    build_st(n);

    for (int i = 0; i < m; i++) {
        int type, l, r;
        cin >> type >> l >> r;
        l--; 
        r--;
        if (type == 1) {
            cout << query_min(l, r) << "\n";
        } else {
            cout << query_max(l, r) << "\n";
        }
    }

    return 0;
}
import java.util.Scanner;

public class Main {
    static final int MAXN = 500005;
    static final int LOGN = 20;

    static int[] a = new int[MAXN];
    static int[][] stMin = new int[LOGN][MAXN];
    static int[][] stMax = new int[LOGN][MAXN];
    static int[] logTable = new int[MAXN];

    static void buildSt(int n) {
        logTable[1] = 0;
        for (int i = 2; i <= n; i++) {
            logTable[i] = logTable[i / 2] + 1;
        }

        for (int i = 0; i < n; i++) {
            stMin[0][i] = a[i];
            stMax[0][i] = a[i];
        }

        for (int p = 1; p < LOGN; p++) {
            for (int i = 0; i + (1 << p) <= n; i++) {
                stMin[p][i] = Math.min(stMin[p - 1][i], stMin[p - 1][i + (1 << (p - 1))]);
                stMax[p][i] = Math.max(stMax[p - 1][i], stMax[p - 1][i + (1 << (p - 1))]);
            }
        }
    }

    static int queryMin(int l, int r) {
        int p = logTable[r - l + 1];
        return Math.min(stMin[p][l], stMin[p][r - (1 << p) + 1]);
    }

    static int queryMax(int l, int r) {
        int p = logTable[r - l + 1];
        return Math.max(stMax[p][l], stMax[p][r - (1 << p) + 1]);
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        for (int i = 0; i < n; i++) {
            a[i] = sc.nextInt();
        }

        buildSt(n);

        for (int i = 0; i < m; i++) {
            int type = sc.nextInt();
            int l = sc.nextInt();
            int r = sc.nextInt();
            l--;
            r--;
            if (type == 1) {
                System.out.println(queryMin(l, r));
            } else {
                System.out.println(queryMax(l, r));
            }
        }
    }
}
import math

def solve():
    n, m = map(int, input().split())
    a = list(map(int, input().split()))

    LOGN = (n).bit_length()
    
    st_min = [[0] * n for _ in range(LOGN)]
    st_max = [[0] * n for _ in range(LOGN)]

    for i in range(n):
        st_min[0][i] = a[i]
        st_max[0][i] = a[i]

    for p in range(1, LOGN):
        for i in range(n - (1 << p) + 1):
            st_min[p][i] = min(st_min[p-1][i], st_min[p-1][i + (1 << (p-1))])
            st_max[p][i] = max(st_max[p-1][i], st_max[p-1][i + (1 << (p-1))])

    log_table = [0] * (n + 1)
    for i in range(2, n + 1):
        log_table[i] = log_table[i // 2] + 1
        
    results = []
    for _ in range(m):
        line = list(map(int, input().split()))
        op_type, l, r = line[0], line[1], line[2]
        l -= 1
        r -= 1
        
        p = log_table[r - l + 1]
        
        if op_type == 1:
            res = min(st_min[p][l], st_min[p][r - (1 << p) + 1])
            results.append(str(res))
        else:
            res = max(st_max[p][l], st_max[p][r - (1 << p) + 1])
            results.append(str(res))
            
    print("\n".join(results))

solve()

算法及复杂度

  • 算法:稀疏表 (Sparse Table, ST)
  • 时间复杂度。预处理阶段需要 ,后续的 次查询每次需要
  • 空间复杂度,主要用于存储稀疏表。