题目链接

动态序列

题目描述

给定一个长度为 的数组,你需要构建一个数据结构来支持以下两种操作:

  1. 查询第 :找出当前整个序列中,从小到大排序后的第 个值。
  2. 单点修改:将第 个位置的元素修改为 注意:本题的值域范围为

解题思路

本题要求我们在支持单点修改的同时,快速查询整个序列的第 小值。

由于本题的值域被限制在了 ,我们可以采用一种非常直接高效的方法,而无需进行离散化

核心思想:值域上的树状数组

  1. 数据结构

    • 我们可以建立一个大小为值域上限(例如 )的树状数组。
    • 在这个树状数组中,下标代表的是具体的数值tr[v] 存储的是当前序列中值为 v 的数有多少个。
  2. 操作转化

    • 初始化:遍历初始数组,对于每个数 a[i],我们直接在树状数组的 a[i] 位置上加一,即 add(a[i], 1)
    • 单点修改 a[p] = x:这个操作非常直观。我们只需要将旧值 a[p] 的计数减一,并将新值 x 的计数加一。这对应树状数组的两次单点更新:add(a[p], -1)add(x, 1)
    • 查询第 :这个问题等价于,在树状数组中找到一个最小的下标 v,使得所有下标 的数的总个数(即前缀和 query(v))第一次达到或超过
      • 这个问题可以通过在树状数组上进行二分查找(或称倍增)来高效解决。我们从二进制的高位向低位开始,尝试一步步确定最终的值。通过检查每个二进制位对应的区间,判断第 k 个数是否落在其中,从而逐步逼近答案。这个过程的复杂度为 ,其中 是值域上限。

由于我们是直接在值域上操作,所有操作都可以在线完成,代码逻辑也大大简化。

代码

#include <iostream>
#include <algorithm>

using namespace std;

const int N = 1000100;
int a[N], tr[N];

void add(int now, int val) {
    for (int i = now; i <= N; i += i & -i)
        tr[i] += val;
}

int kth(int k) {
    int ans = 0, res = 0;
    for (int i = 1 << __lg(N); i > 0; i >>= 1) {
        ans += i;
        if (ans < N && res + tr[ans] < k)
            res += tr[ans];
        else
            ans -= i;
    }
    return ans + 1;
}

int main(void) {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int n, q, i, p, x;
    cin >> n >> q;
    for (i = 1; i <= n; ++i) {
        cin >> a[i];
        add(a[i], 1);
    }
    while (q--) {
        cin >> p >> x;
        if (p == 0)
            cout << kth(x) << '\n';
        else {
            add(a[p], -1);
            a[p] = x;
            add(a[p], 1);
        }
    }
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    static final int MAX_VAL = 1000100;
    static int n;
    static long[] tr = new long[MAX_VAL];
    static int[] a;

    static void add(int now, int val) {
        for (int i = now; i < MAX_VAL; i += i & -i) {
            tr[i] += val;
        }
    }

    static int kth(int k) {
        int ans = 0;
        long res = 0;
        // log2(1000100) is approx 19.9, so 20 is a safe start
        for (int i = 1 << 19; i > 0; i >>= 1) {
            int next_ans = ans + i;
            if (next_ans < MAX_VAL && res + tr[next_ans] < k) {
                res += tr[next_ans];
                ans = next_ans;
            }
        }
        return ans + 1;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        int q = Integer.parseInt(st.nextToken());

        a = new int[n + 1];
        st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; i++) {
            a[i] = Integer.parseInt(st.nextToken());
            add(a[i], 1);
        }

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < q; i++) {
            st = new StringTokenizer(br.readLine());
            int p = Integer.parseInt(st.nextToken());
            int x = Integer.parseInt(st.nextToken());
            if (p == 0) {
                sb.append(kth(x)).append("\n");
            } else {
                add(a[p], -1);
                a[p] = x;
                add(a[p], 1);
            }
        }
        System.out.print(sb.toString());
    }
}
import sys

MAX_VAL = 1000100
tr = [0] * MAX_VAL

def add(now, val):
    i = now
    while i < MAX_VAL:
        tr[i] += val
        i += i & -i

def kth(k):
    ans = 0
    res = 0
    # log2(1000100) is approx 19.9, starting from 1<<19 is efficient
    i = 1 << 19 
    while i > 0:
        next_ans = ans + i
        if next_ans < MAX_VAL and res + tr[next_ans] < k:
            res += tr[next_ans]
            ans = next_ans
        i >>= 1
    return ans + 1

def main():
    input = sys.stdin.readline
    n, q = map(int, input().split())
    a = [0] + list(map(int, input().split()))

    for i in range(1, n + 1):
        add(a[i], 1)

    for _ in range(q):
        p, x = map(int, input().split())
        if p == 0:
            sys.stdout.write(str(kth(x)) + '\n')
        else:
            add(a[p], -1)
            a[p] = x
            add(a[p], 1)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:树状数组上倍增
  • 时间复杂度:,其中 是值域上限(本题中约为 )。初始化需要 ,每次操作(修改或查询)的复杂度为
  • 空间复杂度:,需要存储原始数组和大小为值域的树状数组。