给定长度为数组,执行两种操作:
  1. 求解整个序列的第小的值
  2. 这个位置上的元素修改为,即
每次操作时,输入一个整数p(0\leq p\leq n),如果p = 0,那么执行操作1;如果p \geq 1,则再输入一个整数,然后将的值修改为

传统线段树是维护数组下标区间的信息,例如:
  • 区间和:
  • 区间最大值:
而权值线段树(也称为值域线段树)维护数值的值域区间,例如:
  • 数值在范围里面出现的次数
  • 整个集合中第小的数

在权值线段树中,每一个叶子节点,那么就是该数字,而表示出现的次数。
所以只需要维护次数即可:
因为每一个节点的就是该数字,一共需要个叶子节点,那么需要开,其中也就是说范围是与数值的最大值有关系的
const int N = 1e6 + 10;

struct Node{
    int l, r;
    int cnt;
}tr[4 * N];
正常的线段树函数:
没有函数,一开始所有的值的次数都为,加了函数也没有用
void build(int u, int l, int r){
    if(l == r) tr[u] = {l, r, 0};
    else{
        tr[u] = {l, r, 0};
        int mid = (tr[u].l + tr[u].r) >> 1;
        build(u << 1, l, mid);
        build(u << 1, mid + 1, r);
    }
}
然后就是将数组中的每一个添加进去,就是数值的次数增加:
for(int i = 1; i <= n; i ++){
    modify(1, a[i], 1);
}

void modify(int u, int x, int v){
    if(tr[u].l == tr[u].r){
        tr[u].cnt += v;
                return ;
    }
    else{
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(x <= mid) modify(u << 1, x, v);
        else modify(u << 1 | 1, x, v);
        push_up(u);
    }
}
操作一:求解整个序列的第小的值
肯定不能暴力循环查找,一个区间分为左右两个子区间,如果左子区间的总数,那就说明小的值在左子区间,如果cnt >k,那么就不需要查找左子区间,只需要查找有子区间的第个数
int query(int u, int k){
    if(tr[u].l == tr[u].r) return tr[u].l
    else{
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(k <= tr[u << 1].cnt) return query(u << 1, k);
        else return query(u << 1 | 1, k - tr[u << 1].cnt);
    }
}
操作二:这个位置上的元素修改为,即
这个相当于两个函数,将次数减去,将的值加上
肥肠容易忘记的一步是 在原来数组将修改为

总代码:
#include<bits/stdc++.h>
using namespace std;

#define endl '\n'
#define int long long
#define IOS ios::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
#define HelloWorld IOS;


const int N = 1e6 + 10;

struct Node{
    int l, r;
    int cnt;
}tr[4 * N];

int n, q, a[N];

void push_up(int u){
    tr[u].cnt = tr[u << 1].cnt + tr[u << 1 | 1].cnt;
}

void build(int u, int l, int r){
    if(l == r) tr[u] = {l, r, 0};
    else{
        tr[u] = {l, r, 0};
        int mid = (tr[u].l + tr[u].r) >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
    }
}

void modify(int u, int x, int v){
    if(tr[u].l == tr[u].r){
        tr[u].cnt += v;
        return ;
    }
    else{
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(x <= mid) modify(u << 1, x, v);
        else modify(u << 1 | 1, x, v);
        push_up(u);
    }
}

int query(int u, int k){
    if(tr[u].l == tr[u].r) return tr[u].l;
    else{
        if(k <= tr[u << 1].cnt) return query(u << 1, k);
        else return query(u << 1 | 1, k - tr[u << 1].cnt);
    }
}

signed main(){
    HelloWorld;
    
    cin >> n >> q;
    for(int i = 1; i <= n; i ++) cin >> a[i];
    build(1, 1, 1000000);
    for(int i = 1; i <= n; i ++) modify(1, a[i], 1);
    while(q --){
        int op; cin >> op;
        if(op == 0){
            int k; cin >> k;
            cout << query(1, k) << endl;
        }
        else{
            int x; cin >> x;
            modify(1, a[op], -1);
            modify(1, x, 1);
            a[op] = x;
        }
    }
    return 0;
}