题目链接
题目描述
给定一个长度为 的数组,你需要构建一个数据结构来支持以下两种操作:
- 查询第
小:找出当前整个序列中,从小到大排序后的第
个值。
- 单点修改:将第
个位置的元素修改为
。 注意:本题的值域范围为
。
解题思路
本题要求我们在支持单点修改的同时,快速查询整个序列的第 小值。
由于本题的值域被限制在了 ,我们可以采用一种非常直接高效的方法,而无需进行离散化。
核心思想:值域上的树状数组
-
数据结构:
- 我们可以建立一个大小为值域上限(例如
)的树状数组。
- 在这个树状数组中,下标代表的是具体的数值。
tr[v]
存储的是当前序列中值为v
的数有多少个。
- 我们可以建立一个大小为值域上限(例如
-
操作转化:
- 初始化:遍历初始数组,对于每个数
a[i]
,我们直接在树状数组的a[i]
位置上加一,即add(a[i], 1)
。 - 单点修改
a[p] = x
:这个操作非常直观。我们只需要将旧值a[p]
的计数减一,并将新值x
的计数加一。这对应树状数组的两次单点更新:add(a[p], -1)
和add(x, 1)
。 - 查询第
小:这个问题等价于,在树状数组中找到一个最小的下标
v
,使得所有下标的数的总个数(即前缀和
query(v)
)第一次达到或超过。
- 这个问题可以通过在树状数组上进行二分查找(或称倍增)来高效解决。我们从二进制的高位向低位开始,尝试一步步确定最终的值。通过检查每个二进制位对应的区间,判断第
k
个数是否落在其中,从而逐步逼近答案。这个过程的复杂度为,其中
是值域上限。
- 这个问题可以通过在树状数组上进行二分查找(或称倍增)来高效解决。我们从二进制的高位向低位开始,尝试一步步确定最终的值。通过检查每个二进制位对应的区间,判断第
- 初始化:遍历初始数组,对于每个数
由于我们是直接在值域上操作,所有操作都可以在线完成,代码逻辑也大大简化。
代码
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1000100;
int a[N], tr[N];
void add(int now, int val) {
for (int i = now; i <= N; i += i & -i)
tr[i] += val;
}
int kth(int k) {
int ans = 0, res = 0;
for (int i = 1 << __lg(N); i > 0; i >>= 1) {
ans += i;
if (ans < N && res + tr[ans] < k)
res += tr[ans];
else
ans -= i;
}
return ans + 1;
}
int main(void) {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n, q, i, p, x;
cin >> n >> q;
for (i = 1; i <= n; ++i) {
cin >> a[i];
add(a[i], 1);
}
while (q--) {
cin >> p >> x;
if (p == 0)
cout << kth(x) << '\n';
else {
add(a[p], -1);
a[p] = x;
add(a[p], 1);
}
}
return 0;
}
import java.io.*;
import java.util.*;
public class Main {
static final int MAX_VAL = 1000100;
static int n;
static long[] tr = new long[MAX_VAL];
static int[] a;
static void add(int now, int val) {
for (int i = now; i < MAX_VAL; i += i & -i) {
tr[i] += val;
}
}
static int kth(int k) {
int ans = 0;
long res = 0;
// log2(1000100) is approx 19.9, so 20 is a safe start
for (int i = 1 << 19; i > 0; i >>= 1) {
int next_ans = ans + i;
if (next_ans < MAX_VAL && res + tr[next_ans] < k) {
res += tr[next_ans];
ans = next_ans;
}
}
return ans + 1;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
n = Integer.parseInt(st.nextToken());
int q = Integer.parseInt(st.nextToken());
a = new int[n + 1];
st = new StringTokenizer(br.readLine());
for (int i = 1; i <= n; i++) {
a[i] = Integer.parseInt(st.nextToken());
add(a[i], 1);
}
StringBuilder sb = new StringBuilder();
for (int i = 0; i < q; i++) {
st = new StringTokenizer(br.readLine());
int p = Integer.parseInt(st.nextToken());
int x = Integer.parseInt(st.nextToken());
if (p == 0) {
sb.append(kth(x)).append("\n");
} else {
add(a[p], -1);
a[p] = x;
add(a[p], 1);
}
}
System.out.print(sb.toString());
}
}
import sys
MAX_VAL = 1000100
tr = [0] * MAX_VAL
def add(now, val):
i = now
while i < MAX_VAL:
tr[i] += val
i += i & -i
def kth(k):
ans = 0
res = 0
# log2(1000100) is approx 19.9, starting from 1<<19 is efficient
i = 1 << 19
while i > 0:
next_ans = ans + i
if next_ans < MAX_VAL and res + tr[next_ans] < k:
res += tr[next_ans]
ans = next_ans
i >>= 1
return ans + 1
def main():
input = sys.stdin.readline
n, q = map(int, input().split())
a = [0] + list(map(int, input().split()))
for i in range(1, n + 1):
add(a[i], 1)
for _ in range(q):
p, x = map(int, input().split())
if p == 0:
sys.stdout.write(str(kth(x)) + '\n')
else:
add(a[p], -1)
a[p] = x
add(a[p], 1)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:树状数组上倍增
- 时间复杂度:
,其中
是值域上限(本题中约为
)。初始化需要
,每次操作(修改或查询)的复杂度为
。
- 空间复杂度:
,需要存储原始数组和大小为值域的树状数组。