题目链接
题目描述
给定一个长度为 的静态数组
。有
次查询,每次查询给定一个区间
,要求回答该区间的最大值或最小值。
解题思路
本题是静态区间最值查询(Range Minimum/Maximum Query, RMQ)的模板题。由于数组内容不会改变,且查询次数较多,我们需要一个比每次都遍历区间更高效的数据结构。
解决此问题的经典高效算法是 稀疏表(Sparse Table, ST)。该算法的核心思想是利用动态规划和倍增的思想,预处理出所有长度为 的区间的最大值和最小值,从而实现
的查询。
-
预处理 (
)
- 我们创建两个二维数组,
st_min[p][i]
和st_max[p][i]
。 st_min[p][i]
存储从下标开始,长度为
的区间
[i, i + 2^p - 1]
内的最小值。st_max
同理。- 基础状态 (
): 区间长度为
,所以
st_min[0][i] = st_max[0][i] = a[i]
。 - 递推关系: 对于
,一个长度为
的区间可以看作是两个长度为
的重叠子区间的并集。因此:
st_min[p][i] = min(st_min[p-1][i], st_min[p-1][i + (1 << (p-1))])
st_max[p][i] = max(st_max[p-1][i], st_max[p-1][i + (1 << (p-1))])
- 为了在查询时快速确定
的值,我们还需要预处理一个对数表
log_table
,其中log_table[i]
存储。
- 我们创建两个二维数组,
-
查询 (
)
- 对于任意查询区间
[l, r]
,我们首先计算其长度len = r - l + 1
。 - 然后,利用对数表找到最大的整数
使得
,即
。
- 查询区间
[l, r]
可以被两个长度为的区间
[l, l + 2^p - 1]
和[r - 2^p + 1, r]
完全覆盖。 - 由于
min
和max
运算是幂等的(即op(x, x) = x
),两个区间重叠的部分不会影响最终结果。因此:query_min(l, r) = min(st_min[p][l], st_min[p][r - (1 << p) + 1])
query_max(l, r) = max(st_max[p][l], st_max[p][r - (1 << p) + 1])
- 对于任意查询区间
代码
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;
const int MAXN = 500005;
const int LOGN = 20;
int a[MAXN];
int st_min[LOGN][MAXN];
int st_max[LOGN][MAXN];
int log_table[MAXN];
void build_st(int n) {
log_table[1] = 0;
for (int i = 2; i <= n; i++) {
log_table[i] = log_table[i / 2] + 1;
}
for (int i = 0; i < n; i++) {
st_min[0][i] = a[i];
st_max[0][i] = a[i];
}
for (int p = 1; p < LOGN; p++) {
for (int i = 0; i + (1 << p) <= n; i++) {
st_min[p][i] = min(st_min[p - 1][i], st_min[p - 1][i + (1 << (p - 1))]);
st_max[p][i] = max(st_max[p - 1][i], st_max[p - 1][i + (1 << (p - 1))]);
}
}
}
int query_min(int l, int r) {
int p = log_table[r - l + 1];
return min(st_min[p][l], st_min[p][r - (1 << p) + 1]);
}
int query_max(int l, int r) {
int p = log_table[r - l + 1];
return max(st_max[p][l], st_max[p][r - (1 << p) + 1]);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int n, m;
cin >> n >> m;
for (int i = 0; i < n; i++) {
cin >> a[i];
}
build_st(n);
for (int i = 0; i < m; i++) {
int type, l, r;
cin >> type >> l >> r;
l--;
r--;
if (type == 1) {
cout << query_min(l, r) << "\n";
} else {
cout << query_max(l, r) << "\n";
}
}
return 0;
}
import java.util.Scanner;
public class Main {
static final int MAXN = 500005;
static final int LOGN = 20;
static int[] a = new int[MAXN];
static int[][] stMin = new int[LOGN][MAXN];
static int[][] stMax = new int[LOGN][MAXN];
static int[] logTable = new int[MAXN];
static void buildSt(int n) {
logTable[1] = 0;
for (int i = 2; i <= n; i++) {
logTable[i] = logTable[i / 2] + 1;
}
for (int i = 0; i < n; i++) {
stMin[0][i] = a[i];
stMax[0][i] = a[i];
}
for (int p = 1; p < LOGN; p++) {
for (int i = 0; i + (1 << p) <= n; i++) {
stMin[p][i] = Math.min(stMin[p - 1][i], stMin[p - 1][i + (1 << (p - 1))]);
stMax[p][i] = Math.max(stMax[p - 1][i], stMax[p - 1][i + (1 << (p - 1))]);
}
}
}
static int queryMin(int l, int r) {
int p = logTable[r - l + 1];
return Math.min(stMin[p][l], stMin[p][r - (1 << p) + 1]);
}
static int queryMax(int l, int r) {
int p = logTable[r - l + 1];
return Math.max(stMax[p][l], stMax[p][r - (1 << p) + 1]);
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
for (int i = 0; i < n; i++) {
a[i] = sc.nextInt();
}
buildSt(n);
for (int i = 0; i < m; i++) {
int type = sc.nextInt();
int l = sc.nextInt();
int r = sc.nextInt();
l--;
r--;
if (type == 1) {
System.out.println(queryMin(l, r));
} else {
System.out.println(queryMax(l, r));
}
}
}
}
import math
def solve():
n, m = map(int, input().split())
a = list(map(int, input().split()))
LOGN = (n).bit_length()
st_min = [[0] * n for _ in range(LOGN)]
st_max = [[0] * n for _ in range(LOGN)]
for i in range(n):
st_min[0][i] = a[i]
st_max[0][i] = a[i]
for p in range(1, LOGN):
for i in range(n - (1 << p) + 1):
st_min[p][i] = min(st_min[p-1][i], st_min[p-1][i + (1 << (p-1))])
st_max[p][i] = max(st_max[p-1][i], st_max[p-1][i + (1 << (p-1))])
log_table = [0] * (n + 1)
for i in range(2, n + 1):
log_table[i] = log_table[i // 2] + 1
results = []
for _ in range(m):
line = list(map(int, input().split()))
op_type, l, r = line[0], line[1], line[2]
l -= 1
r -= 1
p = log_table[r - l + 1]
if op_type == 1:
res = min(st_min[p][l], st_min[p][r - (1 << p) + 1])
results.append(str(res))
else:
res = max(st_max[p][l], st_max[p][r - (1 << p) + 1])
results.append(str(res))
print("\n".join(results))
solve()
算法及复杂度
- 算法:稀疏表 (Sparse Table, ST)
- 时间复杂度:
。预处理阶段需要
,后续的
次查询每次需要
。
- 空间复杂度:
,主要用于存储稀疏表。