*数据结构模板

题目

530. 列队
在这里插入图片描述

算法标签: 模拟, 线段树, 线段树动态开点, 树状数组, 平衡树

思路

首先考虑简单情况, 如果只有一行, 删除一个位置, 然后在后面再添加该位置, 以及查询第 i i i个未被删除的位置, 可以使用线段树进行维护, 线段树维护管理区间中有多少个数已经被删除

但是对于矩阵的情况不仅仅需要从右向左合并还需要从下到上合并, 并且注意到对于点 ( x , y ) (x, y) (x,y)进行操作, 只会影响到第 x x x行和第 m m m列, 因此可以每一行维护线段树, 最后一列也也维护线段树(每一行只维护 m − 1 m - 1 m1个数), 但是还需要回到集合中, 需要开一个数组存储

因为线段树需要开 4 4 4倍空间, 并且每一行都需要开线段树, 直接开内存一定会爆炸, 因此需要对线段树进行动态开点

在这里插入图片描述
假如当前删除的位置是 ( x , y ) (x, y) (x,y), 对行线段树执行单点删除, 然后将删除的数顺次加入到行数组中, 添加的数就是第 m m m列的从上向下数第 x x x未被删除的数, 再执行查询操作, 然后再对最后一列的线段树和数组执行删除和添加操作

*前置代码

计算动态开点需要的节点数

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

typedef long long LL;
const int N = 300010;

int build(int l, int r, int depth) {
   
	if (l == r) return depth;
	int mid = l + r >> 1;
	return max(build(l, mid, depth + 1), build(mid + 1, r, depth + 1));
}

int main() {
   
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);

	int depth = build(1, 600000, 1);

	cout << depth << "\n";
	return 0;
}

每次询问最多更改 21 21 21层, 一共 3 × 1 0 5 3 \times 10 ^ 5 3×105个询问, 因此点数 Q × 21 Q \times 21 Q×21

完整注释代码

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>

using namespace std;

typedef long long LL;

const int N = 300010, M = N * 21; // N是最大行数+Q,M是线段树节点数

int n, m, Q;
// 线段树节点,记录左右儿子和被删除的点的数量
struct Node {
   
	int ls, rs, cnt;
} tr[M];
// root数组存储每行的线段树根节点,idx是动态开点的索引
int root[N], idx;
// g数组存储每行和最后一列的动态添加的学生编号
vector<LL> g[N];

// 在线段树中查询第x个未被删除的位置
int query(int &u, int l, int r, int x) {
   
	if (!u) u = ++idx; // 动态开点
	if (l == r) return r; // 找到目标位置
	int mid = l + r >> 1;
	int left = mid - l + 1 - tr[tr[u].ls].cnt; // 左区间未被删除的数量
	if (x <= left) return query(tr[u].ls, l, mid, x);
	return query(tr[u].rs, mid + 1, r, x - left);
}

// 在线段树中删除位置x
void update(int &u, int l, int r, int x) {
   
	if (l == r) {
   
		tr[u].cnt++; // 标记为已删除
	}
	else {
   
		int mid = l + r >> 1;
		if (x <= mid) update(tr[u].ls, l, mid, x);
		else update(tr[u].rs, mid + 1, r, x);
		tr[u].cnt = tr[tr[u].ls].cnt + tr[tr[u].rs].cnt; // 更新删除数量
	}
}

int main() {
   
	scanf("%d%d%d", &n, &m, &Q);
	int L = max(n, m) + Q; // 线段树的最大长度

	while (Q--) {
   
		int x, y;
		scanf("%d%d", &x, &y);

		if (y == m) {
   
			// 处理最后一列
			int r = query(root[n + 1], 1, L, x); // 找到第x个未被删除的位置
			update(root[n + 1], 1, L, r); // 删除该位置
			LL id;
			if (r <= n) {
   
				id = (r - 1LL) * m + m; // 原始编号
			}
			else {
   
				id = g[n + 1][r - n - 1]; // 动态添加的编号
			}
			printf("%lld\n", id);
			g[n + 1].push_back(id); // 添加到最后一列的末尾
		}
		else {
   
			// 处理第x行的前m-1列
			int c = query(root[x], 1, L, y); // 找到第y个未被删除的位置
			update(root[x], 1, L, c); // 删除该位置
			LL id;
			if (c < m) {
   
				id = (x - 1LL) * m + c; // 原始编号
			}
			else {
   
				id = g[x][c - m]; // 动态添加的编号
			}
			printf("%lld\n", id);

			// 将离队的学生添加到最后一列
			int r = query(root[n + 1], 1, L, x); // 找到第x个未被删除的位置
			update(root[n + 1], 1, L, r); // 删除该位置
			LL last_id;
			if (r <= n) {
   
				last_id = (r - 1LL) * m + m; // 原始编号
			}
			else {
   
				last_id = g[n + 1][r - n - 1]; // 动态添加的编号
			}
			g[x].push_back(last_id); // 将该学生添加到第x行的末尾
			g[n + 1].push_back(id); // 将离队的学生添加到最后一列的末尾
		}
	}

	return 0;
}

精简注释代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

typedef long long LL;
const int N = 300010, M = N * 22;

int n, m, q;
struct Node {
   
	int ls, rs, cnt;
} tr[M];
int root[N], idx;
vector<LL> g[N];

int query(int &u, int l, int r, int x) {
   
	if (!u) u = ++idx;
	if (l == r) return r;
	int mid = l + r >> 1;

	// 左区间未删除的数量
	int l_cnt = mid - l + 1 - tr[tr[u].ls].cnt;
	if (x <= l_cnt) return query(tr[u].ls, l, mid, x);
	return query(tr[u].rs, mid + 1, r, x - l_cnt);
}

void remove(int &u, int l, int r, int x) {
   
	if (l == r) {
   
		tr[u].cnt++;
		return;
	}
	int mid = l + r >> 1;
	if (x <= mid) remove(tr[u].ls, l, mid, x);
	else remove(tr[u].rs, mid + 1, r, x);
	tr[u].cnt = tr[tr[u].ls].cnt + tr[tr[u].rs].cnt;
}

int main() {
   
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);

	cin >> n >> m >> q;
	// 线段树最大长度
	int rb = max(n, m) + q;
	while (q--) {
   
		int x, y;
		cin >> x >> y;

		// 处理最后一列
		if (y == m) {
   
			int r = query(root[n + 1], 1, rb, x);
			remove(root[n + 1], 1, rb, r);
			LL id = r <= n ? (r - 1ll) * m + m : g[n + 1][r - n - 1];
			cout << id << "\n";
			g[n + 1].push_back(id);
		}
		else {
   
			int c = query(root[x], 1, rb, y);
			remove(root[x], 1, rb, c);
			LL id = c < m ? (x - 1ll) * m + c : g[x][c - m];

			cout << id << "\n";

			int r = query(root[n + 1], 1, rb, x);
			remove(root[n + 1], 1, rb, r);
			LL n_id = r <= n ? (r - 1ll) * m + m : g[n + 1][r - n - 1];
			g[x].push_back(n_id);
			g[n + 1].push_back(id);
		}
	}

	return 0;
}