目录
引言
本文章介绍的是常用莫队的思路和代码实现
- 基础莫队
- 带修改的莫队
- *回滚莫队
- 树上莫队
- 二次离线莫队
算法能解决的问题
是一种用于高效处理离线区间查询问题(区间不同数字和不同数字某个属性)的算法,尤其适用于静态数组上的区间查询
基础莫队


询问某一段区间中不同数字的数字的个数, 序列长度 5 × 1 0 4 5 \times 10 ^ 4 5×104, 询问个数 2 × 1 0 5 2 \times 10 ^ 5 2×105
假设最坏情况下, 每次查询一个区间 n n n, 开一个数组用于记录每个数字出现的次数, 算法时间复杂度 O ( q n ) O(qn) O(qn), 无法通过
莫队优化
开一个 c n t cnt cnt数组用来记录每个数字出现的次数
( 1 ) (1) (1) 对查询区间进行排序

假设当前查询区间是 [ i , j ] [i, j] [i,j]

蓝色部分是下一段查询的区间
( 2 ) (2) (2)对于每个区间
-
首先移动 j j j指针, 移动到下一次查询的右端点上, 同时对于当前数字 x x x, 如果未出现过那么不同数字的数量 + 1 +1 +1, 否则不同数字的出现次数不发生变化, 同时累计 c n t cnt cnt数组的值
-
再将指针 i i i移动到下一次查询的左端点, 对于当前需要删除的数字 x x x, 如果出现次数大于 1 1 1, 那么 c n t ( x ) − 1 cnt(x) - 1 cnt(x)−1, 不对答案产生影响, 否则答案 − 1 -1 −1
因为每次移动指针最坏情况下是 O ( n ) O(n) O(n)次, 算法时间复杂度最坏 O ( q n ) O(qn) O(qn), 还是没办法通过
算法核心优化:因为算法瓶颈在指针会移动 O ( n q ) O(nq) O(nq)次数, 尝试使得右指针是单调的(不会向回移动), 左指针分块, 具体的来说
区间左端点按照分块的编号排序, 双关键字排序
- 如果分块编号相同按照区间右端点从小到大排序
- 如果分块编号不同, 块小的在前面
将所有查询分为 n \sqrt n n块, 每一块长度是也是 n \sqrt n n, 块内部区间的右端点是递增的
对于右指针来说, 块内走的次数不会超过 n n n, 一共 n \sqrt n n块, 最多移动 n n n \sqrt n nn次
左指针分为两种情况
- 块内最多移动 n \sqrt n n次, 最多 q q q个询问, 算法时间复杂度最差 O ( q n ) O(q \sqrt n) O(qn)
- 块间最多移动 2 n 2 \sqrt n 2n次, 最多跨越 n − 1 \sqrt n - 1 n−1个块, 最差 O ( 2 n ) O(2n) O(2n)
因此左指针的最差时间复杂度是 O ( q n ) O(q \sqrt n) O(qn)
左右时间取最大值, 因此优化后的算法时间复杂度最坏情况下是 O ( q n ) O(q \sqrt n) O(qn)
假设块的大小是 a a a, 右指针的移动次数是 n 2 a \frac{n ^ 2}{a} an2, 左指针的最大移动次数是 m a ma ma, 也就是
a = n 2 m a = \sqrt {\frac{n ^ 2}{m}} a=mn2
左右指针移动的次数相当
如果 a = n a = \sqrt n a=n比较慢, 尝试将 a a a变为 n 2 m \sqrt {\frac{n ^ 2}{m}} mn2
核心代码
// i表示左端点, j表示右端点
for (int k = 0, i = 1, j = 0, res = 0; k < m; ++k) {
// l, r分别代表当前区间的左右端点
auto [l, r, id] = q[k];
while (j < r) add(w[++j], res);
while (j > r) del(w[j--], res);
while (i < l) del(w[i++], res);
while (i > l) add(w[--i], res);
ans[id] = res;
}
示例代码
#include <bits/stdc++.h>
using namespace std;
const int N = 50010, M = 2e5 + 10, S = 1e6 + 10;
int n, m, len;
int w[N], cnt[S], ans[M];
struct Ask {
int l, r, id;
} q[M];
int get(int i) {
return i / len;
}
bool cmp(const Ask &a, const Ask &b) {
int ba = get(a.l), bb = get(b.l);
if (ba == bb) return a.r < b.r;
return ba < bb;
}
void add(int x, int &ans) {
if (!cnt[x]) ans++;
cnt[x]++;
}
void del(int x, int &ans) {
cnt[x]--;
if (!cnt[x]) ans--;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n;
for (int i = 1; i <= n; ++i) cin >> w[i];
cin >> m;
// 计算块长
len = sqrt(n * n / m + 1);
for (int i = 0; i < m; ++i) {
int l, r;
cin >> l >> r;
q[i] = {
l, r, i};
}
sort(q, q + m, cmp);
for (int k = 0, i = 1, j = 0, res = 0; k < m; ++k) {
auto [l, r, id] = q[k];
while (j < r) add(w[++j], res);
while (j > r) del(w[j--], res);
while (i < l) del(w[i++], res);
while (i > l) add(w[--i], res);
ans[id] = res;
}
for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
return 0;
}
带修莫队


因为添加了修改操作, 因此需要加一个维度时间戳

因此询问需要多一个属性, 在第 k k k个修改之后, 第 k + 1 k + 1 k+1个修改之前, 对于当前操作的时间戳是 k k k, 因此询问的属性是三元组 ( l , r , k ) (l, r, k) (l,r,k)
指针 i , j i, j i,j的移动不会影响原序列的数值, 但是时间戳 t t t移动可能导致序列数值变化, 具体的来说

假设 t t t从 k k k变到了 k + 1 k + 1 k+1, 每个操作只会修改一个数
- 假设修改的数字不在区间中, 对 c n t cnt cnt数组不产生影响
- 假设修改的数字在区间中, 做如下操作, 首先在 c n t cnt cnt中减去原来的数, 然后再加上新的数, 操作时间复杂度是 O ( 1 ) O(1) O(1)的
因此 t t t发生变化对时间复杂度不产生影响
假设时间 k k k数值是 x x x, 在 k + 1 k + 1 k+1的数值变为了 x ′ x' x′, 如果需要从时间 k + 1 k + 1 k+1回到 k k k, 需要将数据恢复
具体的来说, 如果在 k + 1 k + 1 k+1时间将数值修改为 x ′ x' x′, 从 k + 1 k + 1 k+1时间回到 k k k时间, swap()
算法核心
使得一个指针单调, 剩下的指针分块, 按照修改时间递增, 区间 l , r l, r l,r分块
设块的大小是 a a a, 块的数量是 n a \frac{n}{a} an
- l l l指针, 块内移动数量是 m a ma ma, 块间移动是 2 a ⋅ n a 2a \cdot \frac{n}{a} 2a⋅an, 移动次数 O ( m a + n ) O(ma + n) O(ma+n)
- t t t指针, t t t递增, l l l有 n a \frac{n}{a} an块, r r r同理, 移动次数 O ( t ⋅ n a n a ) O(t \cdot \frac{n}{a} \frac{n}{a}) O(t⋅anan)
- l l l固定, r r r块内移动最差 m a ma ma, 块间移动是 O ( n ) O(n) O(n), 对于每个 l l l来说, r r r都是移动 O ( n ) O(n) O(n), 因此块间 O ( n 2 a ) O(\frac{n ^ 2}{a}) O(an2), 总的移动次数 O ( m a + n 2 a ) O(ma + \frac{n ^ 2}{a}) O(ma+an2)
块大小取 n t 3 \sqrt[3]{nt} 3nt
核心代码
bool cmp(const Ask &a, const Ask &b) {
int al = get(a.l), ar = get(a.r);
int bl = get(b.l), br = get(b.r);
// 先按照区间的左端点所在的块编号排
if (al != bl) return al < bl;
// 如果左端点块相等, 按照右端点所在的块排
if (ar != br) return ar < br;
// 都相等说明两个区间在一个块内, 按照时间戳排
return a.t < b.t;
}
int res = 0;
for (int k = 1, i = 1, j = 0, t = 0; k <= qc; ++k) {
int l = q[k].l, r = q[k].r, tm = q[k].t, id = q[k].id;
while (j < r) add(w[++j], res);
while (j > r) del(w[j--], res);
while (i < l) del(w[i++], res);
while (i > l) add(w[--i], res);
//当前时间戳 < tm, 需要向上移动
while (t < tm) {
t++;
// 如果修改位置落在了当前区间, 那么执行修改
if (c[t].p >= i && c[t].p <= j) {
del(w[c[t].p], res);
add(c[t].c, res);
}
swap(w[c[t].p], c[t].c);
}
// 当前时间戳大于tm, 向下移动
while (t > tm) {
if (c[t].p >= i && c[t].p <= j) {
del(w[c[t].p], res);
add(c[t].c, res);
}
// 修改前的数值被交换到了c[t]中, 如果得到, 直接swap回来
swap(w[c[t].p], c[t].c);
t--;
}
ans[id] = res;
}
示例代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10, M = 1e5 + 10, S = 1e6 + 10;
int n, m, len, cc, qc;
int w[N], ans[M], cnt[S];
struct Ask {
int l, r, id, t;
} q[M];
struct M {
int p, c;
} c[M];
int get(int i) {
return i / len;
}
bool cmp(const Ask &a, const Ask &b) {
int al = get(a.l), ar = get(a.r);
int bl = get(b.l), br = get(b.r);
if (al != bl) return al < bl;
if (ar != br) return ar < br;
return a.t < b.t;
}
void add(int x, int &ans) {
if (!cnt[x]) ans++;
cnt[x]++;
}
void del(int x, int &ans) {
cnt[x]--;
if (!cnt[x]) ans--;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> w[i];
for (int i = 0; i < m; ++i) {
int a, b;
char ch;
cin >> ch >> a >> b;
if (ch == 'Q') q[++qc] = {
a, b, qc, cc};
else c[++cc] = {
a, b};
}
// 计算块的长度
len = cbrt((double) n * max(1, cc)) + 1;
// 如果查询区间左右端点在同一个块内, 按照时间戳顺序排序
sort(q + 1, q + qc + 1, cmp);
int res = 0;
for (int k = 1, i = 1, j = 0, t = 0; k <= qc; ++k) {
int l = q[k].l, r = q[k].r, tm = q[k].t, id = q[k].id;
while (j < r) add(w[++j], res);
while (j > r) del(w[j--], res);
while (i < l) del(w[i++], res);
while (i > l) add(w[--i], res);
//当前时间戳 < tm, 需要向上移动
while (t < tm) {
t++;
// 如果修改位置落在了当前区间, 那么执行修改
if (c[t].p >= i && c[t].p <= j) {
del(w[c[t].p], res);
add(c[t].c, res);
}
swap(w[c[t].p], c[t].c);
}
// 当前时间戳大于tm, 向下移动
while (t > tm) {
if (c[t].p >= i && c[t].p <= j) {
del(w[c[t].p], res);
add(c[t].c, res);
}
// 修改前的数值被交换到了c[t]中, 如果得到, 直接swap回来
swap(w[c[t].p], c[t].c);
t--;
}
ans[id] = res;
}
for (int i = 1; i <= qc; ++i) cout << ans[i] << '\n';
return 0;
}
回滚莫队

每次询问一个区间 [ l , r ] [l, r] [l,r], 很多种不同的数
, 对于每个数 x x x, 假设出现次数是 t t t, 每个数的重要度等于 t ⋅ x t \cdot x t⋅x
希望求出区间内重要度的最大值
插入一个数求最大值直接取 m a x max max, 但是删除后很难维护最大值(可能需要开一个大根堆)
具体的算法步骤
(1) r r r单调, l l l分块进行排序
(2) 考虑块内如何计算

假设当前考虑的是块 1 1 1, 左端点一定在块 1 1 1
- 假设右端点也在块 1 1 1, 直接暴力计算( O ( n ) O(\sqrt n) O(n))所有询问
- 剩下的询问, 右端点在后面

将区间分为两部分, 如果是右边部分, 因为 r r r是递增的, 因此直接插入

但是, c n t , r e s cnt, res cnt,res维护的不是整个查询区间, 而是从下一个块开始的信息, 因为左边的长度 ≤ n \le \sqrt n ≤n, 因此左边部分暴力计算
左边暴力加之后, 再暴力回滚回去
注意数据范围很大, 需要开 l o n g l o n g long \; long longlong, 并且需要离散化
核心代码
void solve() {
for (int x = 0; x < m; ) {
int y = x;
// x, y之间的部分就是当前块的部分
while (y < m && get(q[y].l) == get(q[x].l)) y++;
// 当前块的尾端点(块的第一个位置 + 块长 - 1)
int rb = get(q[x].l) * len + len - 1;
// 暴力求查询区间右端点在块内
while (x < y && q[x].r <= rb) {
LL res = 0;
auto [l, r, id] = q[x];
for (int i = l; i <= r; ++i) add(w[i], res);
ans[id] = res;
// 每个询问的cnt数组产生的贡献不同需要将当前查询区间的影响减去
for (int i = l; i <= r; ++i) cnt[w[i]]--;
x++;
}
LL res = 0;
int i = rb + 1, j = rb;
while (x < y) {
auto [l, r, id] = q[x];
// 将区间右端点移动到r
while (j < r) add(w[++j], res);
LL backup = res;
// 初始状态在块外部, 移动到块内的同时累计答案, 并且提前备份res, 计算好结果后恢复
while (i > l) add(w[--i], res);
ans[id] = res;
// 将加入的数删除
while (i < rb + 1) cnt[w[i++]]--;
res = backup;
x++;
}
memset(cnt, 0, sizeof cnt);
}
for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
}
示例代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, M = 1e5 + 10;
int n, m, len;
int w[N], cnt[N];
LL ans[M];
vector<int> vec;
struct Ask {
int l, r, id;
} q[M];
int get(int x) {
return x / len;
}
int find(int x) {
return lower_bound(vec.begin(), vec.end(), x) - vec.begin();
}
bool cmp(const Ask &a, const Ask &b) {
int ba = get(a.l), bb = get(b.l);
if (ba != bb) return ba < bb;
return a.r < b.r;
}
void add(int x, LL &ans) {
cnt[x]++;
ans = max(ans, (LL) cnt[x] * vec[x]);
}
void solve() {
for (int x = 0; x < m; ) {
int y = x;
// x, y之间的部分就是当前块的部分
while (y < m && get(q[y].l) == get(q[x].l)) y++;
// 当前块的尾端点(块的第一个位置 + 块长 - 1)
int rb = get(q[x].l) * len + len - 1;
// 暴力求查询区间右端点在块内
while (x < y && q[x].r <= rb) {
LL res = 0;
auto [l, r, id] = q[x];
for (int i = l; i <= r; ++i) add(w[i], res);
ans[id] = res;
// 每个询问的cnt数组产生的贡献不同需要将当前查询区间的影响减去
for (int i = l; i <= r; ++i) cnt[w[i]]--;
x++;
}
LL res = 0;
int i = rb + 1, j = rb;
while (x < y) {
auto [l, r, id] = q[x];
// 将区间右端点移动到r
while (j < r) add(w[++j], res);
LL backup = res;
// 初始状态在块外部, 移动到块内的同时累计答案, 并且提前备份res, 计算好结果后恢复
while (i > l) add(w[--i], res);
ans[id] = res;
// 将加入的数删除
while (i < rb + 1) cnt[w[i++]]--;
res = backup;
x++;
}
memset(cnt, 0, sizeof cnt);
}
for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m;
len = sqrt(n);
for (int i = 1; i <= n; ++i) {
cin >> w[i];
vec.push_back(w[i]);
}
// 原数组进行离散化
sort(vec.begin(), vec.end());
vec.erase(unique(vec.begin(), vec.end()), vec.end());
for (int i = 1; i <= n; ++i) w[i] = find(w[i]);
for (int i = 0; i < m; ++i) {
int l, r;
cin >> l >> r;
q[i] = {
l, r, i};
}
sort(q, q + m, cmp);
solve();
return 0;
}
树上莫队

求 u u u到 v v v上有多少个不同的
首先是暴力做法, 假设当当前询问的路径长度是 n n n, 统计不同数字出现的次数, 可以做到枚举一遍路径上的点就知道不同数字的个数(开一个桶, 记录每个数字出现的次数, 如果重复出现答案 − 1 -1 −1), 一共 m m m个询问, 算法时间复杂度 O ( n m ) O(nm) O(nm), 会超时
尝试将树上问题变为区间问题, 也就是将树的序列取出, 这里用到的是欧拉序列

欧拉序列就是递归到当前点的时候记录一遍, 返回的时候再记录一遍
对于上图的欧拉序列是
1 2 2 3 5 5 6 6 7 7 3 4 8 8 4 1 1 \; 2\; 2\; 3\; 5\; 5\; 6\; 6\; 7\; 7\; 3\; 4\; 8\; 8\; 4\; 1 1223556677348841

观察树上路径 1 ⟶ 8 1 \longrightarrow 8 1⟶8
选取路径上 1 1 1最后出现的位置和 8 8 8最开始出现的位置, 发现是
1 2 2 3 5 5 6 6 7 7 3 4 8 1 \; 2\; 2\; 3\; 5\; 5\; 6\; 6\; 7\; 7\; 3\; 4\; 8\; 1223556677348
得到结论
欧拉路径上出现次数是 1 1 1的点是目标路径上的点, 因为访问过没返回
在欧拉序列中定义 f i f_i fi表示节点 i i i第一次出现的位置, e i e_i ei表示节点 i i i最后一次出现的位置
对于树上的任意节点 x , y x, y x,y, 假设 f x < f y f_x < f_y fx<fy
(1) 假设 l c a ( x , y ) = x lca(x, y) = x lca(x,y)=x, 那么在欧拉序列 [ f x , f y ] [f_x, f_y] [fx,fy]上出现一次的点的集合就是目标路径

(2) 假设 l c a ( x , y ) ≠ x lca(x, y) \ne x lca(x,y)=x, 那么在欧拉序列中 [ e x , f y ] [e_x, f_y] [ex,fy]上出现一次的点以及最近公共祖先 l c a ( x , y ) lca(x, y) lca(x,y), 上图中应该是 3 ′ 3' 3′步的时候计算路径
这样就可以将树上询问变为区间询问

对于这样一段欧拉序列, 求的是只出现一次的数里面有多少个数是不同的
算法步骤
- 对节点权值离散化
- 求树的欧拉序列
- 求 l c a lca lca, d e p t h u depth_u depthu代表当前点的深度, f ( u , k ) f(u, k) f(u,k)表示从当前点 u u u向上走 2 k 2 ^ k 2k步的祖先, 可以通过递推计算, 先跳 2 k − 1 2 ^ {k - 1} 2k−1再跳 2 k − 1 2 ^ {k - 1} 2k−1步, 得到递推式 f ( u , k ) = f ( f ( u , k − 1 ) , k − 1 ) f(u, k) = f(f(u, k - 1), k - 1) f(u,k)=f(f(u,k−1),k−1)
预处理这两个数组之后, 假设 d e p t h a < d e p t h b depth_a < depth_b deptha<depthb, 首先将 a a a跳到 b b b的高度, 然后同时向上跳 - 将树的询问变为序列的询问
- 莫队算法求区间不同数字的个数, c n t cnt cnt表示权值出现的次数, s t st st表示节点出现的次数
算法时间瓶颈在莫队, 算法时间复杂度 O ( n n ) O(n \sqrt n) O(nn)
核心代码
int res = 0;
for (int k = 0, i = 1, j = 0; k < m; ++k) {
int l = q[k].l, r = q[k].r, id = q[k].id, p = q[k].p;
while (j < r) add(seq[++j], res);
while (j > r) add(seq[j--], res);
while (i < l) add(seq[i++], res);
while (i > l) add(seq[--i], res);
if (p) add(p, res);
ans[id] = res;
if (p) add(p, res);
}#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10, M = 1e5 + 10;
int n, m, len;
int w[N];
int head[N], ed[N], ne[N], idx;
int depth[N], f[N][17];
int seq[N], ptr, fst[N], lst[N];
int cnt[N], st[N], ans[M];
struct Ask {
int l, r, id;
int p;
} q[M];
vector<int> vec;
int que[N], h, t;
int find(int x) {
return lower_bound(vec.begin(), vec.end(), x) - vec.begin();
}
void add_edge(int u, int v) {
ed[idx] = v, ne[idx] = head[u], head[u] = idx++;
}
// 求欧拉序列
void dfs(int u, int fa) {
seq[++ptr] = u;
fst[u] = ptr;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == fa) continue;
dfs(v, u);
}
seq[++ptr] = u;
lst[u] = ptr;
}
void bfs() {
memset(depth, 0x3f, sizeof depth);
h = 0, t = -1;
depth[0] = 0, depth[1] = 1;
que[++t] = 1;
while (h <= t) {
int u = que[h++];
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (depth[u] + 1 < depth[v]) {
depth[v] = depth[u] + 1;
f[v][0] = u;
for (int k = 1; k <= 16; ++k) f[v][k] = f[f[v][k - 1]][k - 1];
que[++t] = v;
}
}
}
}
int lca(int a, int b) {
if (depth[a] < depth[b]) swap(a, b);
for (int k = 16; k >= 0; --k) {
if (depth[f[a][k]] >= depth[b]) a = f[a][k];
}
if (a == b) return a;
for (int k = 16; k >= 0; --k) {
if (f[a][k] != f[b][k]) {
a = f[a][k];
b = f[b][k];
}
}
return f[a][0];
}
int get(int i) {
return i / len;
}
bool cmp(const Ask &a, const Ask &b) {
int ba = get(a.l), bb = get(b.l);
if (ba != bb) return ba < bb;
return a.r < b.r;
}
void add(int x, int &ans) {
st[x] ^= 1;
// x节点出现了偶数次
if (!st[x]) {
cnt[w[x]]--;
if (!cnt[w[x]]) ans--;
}
else {
if (!cnt[w[x]]) ans++;
cnt[w[x]]++;
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
memset(head, -1, sizeof head);
cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> w[i], vec.push_back(w[i]);
sort(vec.begin(), vec.end());
vec.erase(unique(vec.begin(), vec.end()), vec.end());
for (int i = 1; i <= n; ++i) w[i] = find(w[i]);
for (int i = 0; i < n - 1; ++i) {
int a, b;
cin >> a >> b;
add_edge(a, b), add_edge(b, a);
}
// 处理欧拉序列
dfs(1, -1);
// 处理lca
bfs();
for (int i = 0; i < m; ++i) {
int a, b;
cin >> a >> b;
if (fst[a] > fst[b]) swap(a, b);
int p = lca(a, b);
if (p == a) q[i] = {
fst[a], fst[b], i, 0};
else q[i] = {
lst[a], fst[b], i, p};
}
len = sqrt(ptr);
sort(q, q + m, cmp);
int res = 0;
for (int k = 0, i = 1, j = 0; k < m; ++k) {
auto [l, r, id, p] = q[k];
while (j < r) add(seq[++j], res);
while (j > r) add(seq[j--], res);
while (i < l) add(seq[i++], res);
while (i > l) add(seq[--i], res);
if (p) add(p, res);
ans[id] = res;
if (p) add(p, res);
}
for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
return 0;
}
示例代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10, M = 1e5 + 10;
int n, m, len;
int w[N];
int head[N], ed[N], ne[N], idx;
int depth[N], f[N][17];
int seq[N], ptr, fst[N], lst[N];
int cnt[N], st[N], ans[M];
struct Ask {
int l, r, id;
int p;
} q[M];
vector<int> vec;
int que[N], h, t;
int find(int x) {
return lower_bound(vec.begin(), vec.end(), x) - vec.begin();
}
void add_edge(int u, int v) {
ed[idx] = v, ne[idx] = head[u], head[u] = idx++;
}
// 求欧拉序列
void dfs(int u, int fa) {
seq[++ptr] = u;
fst[u] = ptr;
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (v == fa) continue;
dfs(v, u);
}
seq[++ptr] = u;
lst[u] = ptr;
}
void bfs() {
memset(depth, 0x3f, sizeof depth);
h = 0, t = -1;
depth[0] = 0, depth[1] = 1;
que[++t] = 1;
while (h <= t) {
int u = que[h++];
for (int i = head[u]; ~i; i = ne[i]) {
int v = ed[i];
if (depth[u] + 1 < depth[v]) {
depth[v] = depth[u] + 1;
f[v][0] = u;
for (int k = 1; k <= 16; ++k) f[v][k] = f[f[v][k - 1]][k - 1];
que[++t] = v;
}
}
}
}
int lca(int a, int b) {
if (depth[a] < depth[b]) swap(a, b);
for (int k = 16; k >= 0; --k) {
if (depth[f[a][k]] >= depth[b]) a = f[a][k];
}
if (a == b) return a;
for (int k = 16; k >= 0; --k) {
if (f[a][k] != f[b][k]) {
a = f[a][k];
b = f[b][k];
}
}
return f[a][0];
}
int get(int i) {
return i / len;
}
bool cmp(const Ask &a, const Ask &b) {
int ba = get(a.l), bb = get(b.l);
if (ba != bb) return ba < bb;
return a.r < b.r;
}
void add(int x, int &ans) {
st[x] ^= 1;
// x节点出现了偶数次
if (!st[x]) {
cnt[w[x]]--;
if (!cnt[w[x]]) ans--;
}
else {
if (!cnt[w[x]]) ans++;
cnt[w[x]]++;
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
memset(head, -1, sizeof head);
cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> w[i], vec.push_back(w[i]);
sort(vec.begin(), vec.end());
vec.erase(unique(vec.begin(), vec.end()), vec.end());
for (int i = 1; i <= n; ++i) w[i] = find(w[i]);
for (int i = 0; i < n - 1; ++i) {
int a, b;
cin >> a >> b;
add_edge(a, b), add_edge(b, a);
}
// 处理欧拉序列
dfs(1, -1);
// 处理lca
bfs();
for (int i = 0; i < m; ++i) {
int a, b;
cin >> a >> b;
if (fst[a] > fst[b]) swap(a, b);
int p = lca(a, b);
if (p == a) q[i] = {
fst[a], fst[b], i, 0};
else q[i] = {
lst[a], fst[b], i, p};
}
len = sqrt(ptr);
sort(q, q + m, cmp);
int res = 0;
for (int k = 0, i = 1, j = 0; k < m; ++k) {
int l = q[k].l, r = q[k].r, id = q[k].id, p = q[k].p;
while (j < r) add(seq[++j], res);
while (j > r) add(seq[j--], res);
while (i < l) add(seq[i++], res);
while (i > l) add(seq[--i], res);
if (p) add(p, res);
ans[id] = res;
if (p) add(p, res);
}
for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
return 0;
}
二次离线莫队(Mo’s Algorithm, Secondary Offline)

如果直接用基础莫队每次 a d d ( x ) add(x) add(x), 需要遍历整个区间, 算法时间复杂度会退化到 O ( n 2 ) O(n ^ 2) O(n2)
算法精髓: 莫队移动指针时那些难以计算的 O ( n ) O(n) O(n)级别的修改操作,全部“打包”存起来,等到最后统一处理
(1) 第一步操作
要求 a i ⊕ a j a_i \oplus a_j ai⊕aj有 k k k个 1 1 1
不妨设 a i ⊕ A = Y a_i \oplus A = Y ai⊕A=Y, 并且 Y Y Y的二进制表示下有 k k k个 1 1 1, 由数学变换, A = Y ⊕ a i A = Y \oplus a_i A=Y⊕ai
也就是对于一个数 a i a_i ai, 所有的 A = Y ⊕ a i A = Y \oplus a_i A=Y⊕ai这样的数字都是与其配对的结果
代码表示为
vector<int> nums;
for (int i = 0; i < 1 << 14; i++) // 枚举所有可能的异或结果Y
if (get_count(i) == k)
nums.push_back(i); // nums存下了所有二进制中包含k个1的数字
(2)第二步操作-外层套基础莫队, 对查询进行排序
(3)第三步操作-利用前缀和思想
当我们加入一个新元素 w i w_i wi 时,我们需要知道区间 [ l , r ] [l, r] [l,r] 中有多少个数和它配对
C n t ( [ l , r ] , w i ) = C n t ( [ 1 , r ] , w i ) − C n t ( [ 1 , l − 1 ] , w i ) Cnt([l, r], w_i) = Cnt([1, r], w_i) - Cnt([1, l - 1], w_i) Cnt([l,r],wi)=Cnt([1,r],wi)−Cnt([1,l−1],wi)
也就是说,我们可以预先算前缀中与某个数配对的数量,然后用减法得到区间结果
代码表示为
for (int i = 1; i <= n; i++) {
for (auto y: nums) ++g[w[i] ^ y]; // 1. 统计:w[i]的出现,会让数值为 (w[i]^y) 的需求+1
f[i] = g[w[i + 1]]; // 2. 记录:f[i] 表示在处理完前i个数后,有多少数能和 w[i+1] 配对
}
f i f_i fi 的含义非常关键:它表示在序列的前 i i i 个位置中,有多少个数 w j w_j wj满足 w j ⊕ w i + 1 w_j \oplus w_{i + 1} wj⊕wi+1的结果有 k k k 个 1 1 1
(4) 第四步操作-莫队移动指针算简单的任务, 记录复杂的任务
右指针向右移动为例
-
因为 f R f_R fR 算的是 [ 1 , R ] [1, R] [1,R] 的贡献,而我们实际需要的是 [ L , R ] [L, R] [L,R] 的贡献
-
多算的部分是 [ 1 , L − 1 ] [1, L-1] [1,L−1] 这一段。这段的贡献我们暂时不减,而是把它记在一个任务列表
range[L-1]里(因为难以计算)
(5)二次离线处理(Secondary Offline)
代码表示为
memset(g, 0, sizeof g);
for (int i = 1; i <= n; i++) {
// 1. 更新前缀统计:把位置 i 的数 w[i] 加入统计
for (auto y: nums) ++g[w[i] ^ y];
// 2. 处理任务:检查有没有人需要 [1, i] 这个前缀的信息
for (auto &rg: range[i]) {
int id = rg.id, l = rg.l, r = rg.r, t = rg.t;
for (int x = l; x <= r; x++)
q[id].res += g[w[x]] * t; // 查询 w[x] 在当前前缀 [1,i] 中的配对数量
}
}
这里 g [ x ] g[x] g[x] 的含义是:在当前扫描到的前缀中,有多少个数与数值 x x x 异或后结果有 k k k 个 1 1 1
当我们扫描到位置 i i i 时,所有挂起的任务 r a n g e [ i ] range[i] range[i](即需要前缀 [ 1 , i ] [1, i] [1,i] 信息的任务)就可以被执行了
预处理时间复杂度 O ( C 14 7 ⋅ n ) O(C_{14} ^ 7 \cdot n) O(C147⋅n), 莫队时间复杂度 O ( n n ) O(n\sqrt n) O(nn), 算法时间复杂度 O ( n ⋅ C 14 7 + n n ) O(n\cdot C_{14} ^ 7 + n \sqrt n) O(n⋅C147+nn)
核心代码
for (int t = 0, i = 1, j = 0; t < m; ++t) {
int l = q[t].l, r = q[t].r;
// j指针向右扩张, tsk[i - 1]记录未来[j + 1, r]与[1, i - 1]造成的影响, 也就是配对的数量
if (j < r) tsk[i - 1].push_back({
j + 1, r, t, -1});
while (j < r) q[t].ans += f[j++];
// j指针向左收缩, tsk[i - 1]记录未来[r + 1, j]与[1, i - 1]造成的影响, 也就是配对的数量
if (j > r) tsk[i - 1].push_back({
r + 1, j, t, 1});
while (j > r) q[t].ans -= f[--j];
// i指针向右收缩, 对于当前移除的元素w[i], 希望计算对区间[i + 1, j]的影响, w[i]造成的影响值是f[i - 1]
if (i < l) tsk[j].push_back({
i, l - 1, t, -1});
while (i < l) q[t].ans += f[i - 1] + !k, i++;
// i指针向左扩张, 相当于添加当前元素w[i - 1], 希望计算对区间[i, j]的影响, w[i - 1]造成的影响是f[i - 2]
if (i > l) tsk[j].push_back({
l, i - 1, t, 1});
while (i > l) q[t].ans -= f[i - 2] + !k, i--;
}
memset(g, 0, sizeof g);
for (int i = 1; i <= n; ++i) {
for (auto y : vec) g[w[i] ^ y]++;
for (auto &[l, r, id, t] : tsk[i]) {
for (int x = l; x <= r; ++x) q[id].ans += g[w[x]] * t;
}
}
如何理解 i i i指针?
f [ x ] f[x] f[x]的定义是 w [ x + 1 ] w[x + 1] w[x+1]与前缀 [ 1 , x ] [1, x] [1,x]产生的配对数量
w [ i ] w[i] w[i]产生的影响是 f [ i − 1 ] f[i - 1] f[i−1], w [ i − 1 ] w[i - 1] w[i−1]产生的影响是 f [ i − 2 ] f[i - 2] f[i−2]
示例代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, M = 1e5 + 10;
int n, m, k, len;
int w[N];
LL ans[N];
struct Ask {
int l, r, id;
LL ans;
} q[M];
struct Tsk {
int l, r, id, t;
};
vector<Tsk> tsk[N];
int f[N], g[N];
inline int get_cnt(int x) {
int ans = 0;
while (x) ans += x & 1, x >>= 1;
return ans;
}
inline int get(int i) {
return i / len;
}
bool cmp(const Ask &a, const Ask &b) {
int ba = get(a.l), bb = get(b.l);
if (ba != bb) return ba < bb;
return a.r < b.r;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m >> k;
for (int i = 1; i <= n; ++i) cin >> w[i];
vector<int> vec;
for (int i = 0; i < 1 << 14; ++i) {
if (get_cnt(i) == k) vec.push_back(i);
}
for (int i = 1; i <= n; ++i) {
for (int x : vec) g[w[i] ^ x]++;
f[i] = g[w[i + 1]];
}
for (int i = 0; i < m; ++i) {
int l, r;
cin >> l >> r;
q[i] = {
l, r, i, 0};
}
len = sqrt(n);
sort(q, q + m, cmp);
for (int idx = 0, i = 1, j = 0; idx < m; ++idx) {
int l = q[idx].l, r = q[idx].r;
if (j < r) tsk[i - 1].push_back({
j + 1, r, idx, -1});
while (j < r) q[idx].ans += f[j++];
if (j > r) tsk[i - 1].push_back({
r + 1, j, idx, 1});
while (j > r) q[idx].ans -= f[--j];
if (i < l) tsk[j].push_back({
i, l - 1, idx, -1});
while (i < l) q[idx].ans += f[i - 1] + !k, i++;
if (i > l) tsk[j].push_back({
l, i - 1, idx, 1});
while (i > l) q[idx].ans -= f[i - 2] + !k, i--;
}
memset(g, 0, sizeof g);
for (int i = 1; i <= n; ++i) {
for (int x : vec) g[w[i] ^ x]++;
for (auto &[l, r, id, t] : tsk[i]) {
for (int x = l; x <= r; ++x) {
q[id].ans += g[w[x]] * t;
}
}
}
for (int i = 1; i < m; i++) q[i].ans += q[i - 1].ans;
for (int i = 0; i < m; i++) ans[q[i].id] = q[i].ans;
for (int i = 0; i < m; i++) cout << ans[i] << '\n';
return 0;
}

京公网安备 11010502036488号