算法能解决的问题
多次询问和修改答案具有二分性质, 并且单次二分时间复杂度过高的离线问题
算法原理
无论当前是添加操作还是修改操作都存储在一个集合当中, 直接二分这个集合
以计算区间第 k k k小数字为例
定义二分值域区间 [ l , r ] [l, r] [l,r], 假设操作数量是 c n t cnt cnt, 那么定义如下函数
void solve(int st, int ed, int l, int r)
传入的数值分别是 1 , c n t , m i n v , m a x v 1, cnt, minv, maxv 1,cnt,minv,maxv
核心逻辑
(1) 如果当前操作是修改操作
- 如果修改的值小于等于(因为查询的是第 k k k小数) m i d mid mid, 将该位置 + 1 +1 +1(使用树状数组或者线段树实现), 然后将当前操作分流到左缓冲区
- 如果修改的值大于 m i d mid mid, 因为计算的是第 k k k小数, 需要统计的是 ≤ m i d \le mid ≤mid的数的个数, 当前修改不对这个个数产生影响, 因此直接分流到右缓冲区
(2) 如果当前是查询操作
- 以树状数组为例, 在区间 [ l , r ] [l, r] [l,r]之间小于等于 m i d mid mid的数的个数是 s = S r − S l − 1 s = S_{r} - S_{l - 1} s=Sr−Sl−1
- 如果当前查询的排名 s < k s < k s<k, 说明在由子区间(因为数字不够), 将查询操作分流到右缓冲区, 同时将 k − s k - s k−s, 因为在右缓冲区的排名需要减去当前 [ l , r ] [l, r] [l,r]这些数字
- 如果当前查询的排名 s ≥ k s \ge k s≥k, 说明数字数量够排名, 那么将当前查询操作分流到左缓冲区
(3) 在每一个递归层级需要将树状数组或者线段树清空
(4) 将划分好的左右缓冲区的操作放回到操作序列中, 方便下一层递归层级计算
(5) 分治处理左右缓冲区
算法步骤
- 创建操作结构体节点, 记录查询修改等等操作
- 创建树状数组/线段树等节点, 辅助分治算法统计对于当前值 m i d mid mid的序列的属性(例如小于等于 m i d mid mid的数字数量)
- 实现整体二分核心逻辑, 主要实现分流操作
- 递归分治
- 输出结果
核心逻辑代码实现
void solve(int st, int ed, int l, int r) {
if (st > ed) return;
// 二分找到了答案, 统计[st, ed]的答案
if (l == r) {
for (int i = st; i <= ed; ++i) {
if (op[i].op) ans[op[i].id] = l;
}
return;
}
// 对值域进行二分
int mid = l + r >> 1;
int idx1 = 0, idx2 = 0;
for (int i = st; i <= ed; ++i) {
// 当前是插入或者修改操作
if (!op[i].op) {
// 只统计<= mid的数的个数, 因为求得是区间第k小数
if (op[i].val <= mid) {
add(op[i].pos, 1);
q1[++idx1] = op[i];
}
// 大于mid的数对答案不产生影响
else q2[++idx2] = op[i];
}
else {
// 统计区间内<=mid的数字个数
int s = get(op[i].r) - get(op[i].l - 1);
// 当前区间内的数的个数不够k个, 答案在右子区间, 在右区间的排名应该是k - s
if (s < op[i].k) {
op[i].k -= s;
q2[++idx2] = op[i];
}
// 当前区间的数的个数够k个, 答案在左子区间
else q1[++idx1] = op[i];
}
}
// 清空树状数组
for (int i = 1; i <= idx1; ++i) {
if (!q1[i].op) add(q1[i].pos, -1);
}
// 将缓冲区结果放回操作序列, 方便下一个递归层计算
for (int i = 1; i <= idx1; ++i) op[i + st - 1] = q1[i];
for (int i = 1; i <= idx2; ++i) op[i + st + idx1 - 1] = q2[i];
// 递归分治
solve(st, st + idx1 - 1, l, mid);
solve(st + idx1, ed, mid + 1, r);
}
无修改的区间查询第 k k k小数代码实现

整体二分 + 树状数组实现
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10, M = 1e4 + 10, INF = 2e9;
int n, m;
struct Ask {
int pos, val;
int l, r, k, id;
int op;
} op[N << 1], q1[N << 1], q2[N << 1];
int ans[N << 1];
int cnt;
int tr[N];
int lowbit(int x) {
return x & -x;
}
void add(int u, int val) {
for (int i = u; i <= n; i += lowbit(i)) tr[i] += val;
}
int get(int u) {
int ans = 0;
for (int i = u; i; i -= lowbit(i)) ans += tr[i];
return ans;
}
void solve(int st, int ed, int l, int r) {
if (st > ed) return;
if (l == r) {
for (int i = st; i <= ed; ++i) {
if (op[i].op) ans[op[i].id] = l;
}
return;
}
int mid = l + r >> 1;
int idx1 = 0, idx2 = 0;
for (int i = st; i <= ed; ++i) {
if (!op[i].op) {
if (op[i].val <= mid) {
add(op[i].pos, 1);
q1[++idx1] = op[i];
}
else q2[++idx2] = op[i];
}
else {
// 统计区间内<=mid的数字个数
int s = get(op[i].r) - get(op[i].l - 1);
if (s < op[i].k) {
op[i].k -= s;
q2[++idx2] = op[i];
}
else q1[++idx1] = op[i];
}
}
// 清空树状数组
for (int i = 1; i <= idx1; ++i) {
if (!q1[i].op) add(q1[i].pos, -1);
}
for (int i = 1; i <= idx1; ++i) op[i + st - 1] = q1[i];
for (int i = 1; i <= idx2; ++i) op[i + st + idx1 - 1] = q2[i];
solve(st, st + idx1 - 1, l, mid);
solve(st + idx1, ed, mid + 1, r);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int min_val = INF, max_val = -INF;
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
int val;
cin >> val;
op[++cnt] = {
i, val, 0, 0, 0, 0, 0};
min_val = min(min_val, val), max_val = max(max_val, val);
}
for (int i = 1; i <= m; ++i) {
int l, r, k;
cin >> l >> r >> k;
op[++cnt] = {
0, 0, l, r, k, i, 1};
}
solve(1, cnt, min_val, max_val);
for (int i = 1; i <= m; ++i) cout << ans[i] << '\n';
return 0;
}
单点修改的区间第 k k k小数的代码实现

单点修改, 只需要在操作结构体中记录, 假设当前位置是 w i w_i wi, 那当前位置是贡献是 1 1 1, 假设将 i i i位置修改为了 v a l val val, 那么只需要将原来的贡献 − 1 -1 −1, 新的数值产生新的贡献即可, 具体的来说

注意在操作序列修改后, 需要再原数组也同步修改, 因为多了操作 2 × N 2 \times N 2×N, 操作数量最坏情况下是 3 × N 3 \times N 3×N, 因此保险起见将操作序列数组的空间开到 4 × N 4 \times N 4×N
整体二分 + 树状数组实现
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10, M = 1e5 + 10, INF = 1e9 + 10;
int n, m, cnt;
int a[N];
// op == 0, x数字位置, y数字的值, k是贡献, op == 1, 查询[x, y]第k大数字, id是查询编号
struct Ask {
int op;
int x, y, id, k;
} q[N << 2], q1[N << 2], q2[N << 2];
int ans[N];
int tr[N];
int lowbit(int x) {
return x & -x;
}
void add(int u, int val) {
for (int i = u; i <= n; i += lowbit(i)) tr[i] += val;
}
int get(int u) {
int ans = 0;
for (int i = u; i; i -= lowbit(i)) ans += tr[i];
return ans;
}
void solve(int st, int ed, int l, int r) {
if (st > ed) return;
if (l == r) {
for (int i = st; i <= ed; ++i) {
if (q[i].op) ans[q[i].id] = l;
}
return;
}
int mid = l + r >> 1;
int idx1 = 0, idx2 = 0;
for (int i = st; i <= ed; ++i) {
if (!q[i].op) {
if (q[i].y <= mid) {
add(q[i].x, q[i].k);
q1[++idx1] = q[i];
}
else q2[++idx2] = q[i];
}
else {
int s = get(q[i].y) - get(q[i].x - 1);
if (s < q[i].k) {
q[i].k -= s;
q2[++idx2] = q[i];
}
else q1[++idx1] = q[i];
}
}
// 重置树状数组
for (int i = 1; i <= idx1; ++i) {
if (!q1[i].op) add(q1[i].x, -q1[i].k);
}
for (int i = 1; i <= idx1; ++i) q[i + st - 1] = q1[i];
for (int i = 1; i <= idx2; ++i) q[i + st + idx1 - 1] = q2[i];
solve(st, st + idx1 - 1, l, mid);
solve(st + idx1, ed, mid + 1, r);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int min_val = INF, max_val = 0;
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
q[++cnt] = {
0, i, a[i], 0, 1};
min_val = min(min_val, a[i]), max_val = max(max_val, a[i]);
}
bool st[N] = {
0};
for (int i = 1; i <= m; ++i) {
char c;
cin >> c;
if (c == 'Q') {
int l, r, k;
cin >> l >> r >> k;
q[++cnt] = {
1, l, r, i, k};
st[i] = true;
}
else {
int x, val;
cin >> x >> val;
q[++cnt] = {
0, x, a[x], 0, -1};
q[++cnt] = {
0, x, val, 0, 1};
a[x] = val;
min_val = min(min_val, val), max_val = max(max_val, val);
}
}
solve(1, cnt, min_val, max_val);
for (int i = 1; i <= m; ++i) {
if (st[i]) cout << ans[i] << '\n';
}
return 0;
}
区间修改的区间第 k k k小数的代码实现

因为涉及到区间修改, 和区间查询, 需要构建线段树, 同时为了清空线段树, 另外加一个延迟标记rest, 代表重置
整体二分 + 线段树实现
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 5e4 + 10, M = 5e4 + 10, INF = 5e5 + 10;
int n, m, ans[N], cnt;
struct Ask {
int op;
int l, r, k;
int id;
} q[N << 1], q1[N << 1], q2[N << 1];
struct Node {
int l, r;
LL sum, add;
int rest;
} tr[N << 2];
void pushup(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u) {
Node &ls = tr[u << 1], &rs = tr[u << 1 | 1];
if (tr[u].rest) {
tr[u].rest = 0;
ls.add = ls.sum = 0;
rs.add = rs.sum = 0;
ls.rest = 1, rs.rest = 1;
}
if (tr[u].add) {
ls.add += tr[u].add;
rs.add += tr[u].add;
ls.sum += (ls.r - ls.l + 1) * tr[u].add;
rs.sum += (rs.r - rs.l + 1) * tr[u].add;
tr[u].add = 0;
}
}
void build(int u, int l, int r) {
tr[u] = {
l, r, 0, 0, 0};
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
void modify(int u, int ql, int qr, int val) {
if (tr[u].l >= ql && tr[u].r <= qr) {
tr[u].add += val;
tr[u].sum += (tr[u].r - tr[u].l + 1) * val;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (ql <= mid) modify(u << 1, ql, qr, val);
if (qr > mid) modify(u << 1 | 1, ql, qr, val);
pushup(u);
}
LL query(int u, int ql, int qr) {
if (tr[u].l >= ql && tr[u].r <= qr) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL ans = 0;
if (ql <= mid) ans += query(u << 1, ql, qr);
if (qr > mid) ans += query(u << 1 | 1, ql, qr);
return ans;
}
void solve(int st, int ed, int l, int r) {
if (l == r) {
for (int i = st; i <= ed; ++i) {
if (q[i].op) ans[q[i].id] = l;
}
return;
}
int mid = l + r >> 1;
int idx1 = 0, idx2 = 0;
tr[1].rest = 1;
tr[1].add = tr[1].sum = 0;
for (int i = st; i <= ed; ++i) {
if (!q[i].op) {
if (q[i].k > mid) {
modify(1, q[i].l, q[i].r, 1);
q2[++idx2] = q[i];
}
else q1[++idx1] = q[i];
}
else {
// 统计[l, r]之间大于mid的数字个数
LL t = query(1, q[i].l, q[i].r);
if (t < q[i].k) {
q[i].k -= t;
q1[++idx1] = q[i];
}
else q2[++idx2] = q[i];
}
}
for (int i = 1; i <= idx1; ++i) q[i + st - 1] = q1[i];
for (int i = 1; i <= idx2; ++i) q[i + st + idx1 - 1] = q2[i];
solve(st, st + idx1 - 1, l, mid);
solve(st + idx1, ed, mid + 1, r);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m;
build(1, 1, n);
bool st[N] = {
0};
for (int i = 1; i <= m; ++i) {
int op, l, r, val;
cin >> op >> l >> r >> val;
if (op == 1) q[++cnt] = {
0, l, r, val, 0};
else {
q[++cnt] = {
1, l, r, val, i};
st[i] = true;
}
}
solve(1, cnt, -n, n);
for (int i = 1; i <= m; ++i) {
if (st[i]) cout << ans[i] << '\n';
}
return 0;
}

京公网安备 11010502036488号