引言

本文章介绍的是常用莫队的思路和代码实现

  • 基础莫队
  • 带修改的莫队
  • *回滚莫队
  • 树上莫队
  • 二次离线莫队

算法能解决的问题

是一种用于高效处理离线区间查询问题(区间不同数字和不同数字某个属性)的算法,尤其适用于静态数组上的区间查询

基础莫队

询问某一段区间中不同数字的数字的个数, 序列长度 5 × 1 0 4 5 \times 10 ^ 4 5×104, 询问个数 2 × 1 0 5 2 \times 10 ^ 5 2×105

假设最坏情况下, 每次查询一个区间 n n n, 开一个数组用于记录每个数字出现的次数, 算法时间复杂度 O ( q n ) O(qn) O(qn), 无法通过

莫队优化

开一个 c n t cnt cnt数组用来记录每个数字出现的次数

( 1 ) (1) (1) 对查询区间进行排序


假设当前查询区间是 [ i , j ] [i, j] [i,j]


蓝色部分是下一段查询的区间

( 2 ) (2) (2)对于每个区间

  1. 首先移动 j j j指针, 移动到下一次查询的右端点上, 同时对于当前数字 x x x, 如果未出现过那么不同数字的数量 + 1 +1 +1, 否则不同数字的出现次数不发生变化, 同时累计 c n t cnt cnt数组的值

  2. 再将指针 i i i移动到下一次查询的左端点, 对于当前需要删除的数字 x x x, 如果出现次数大于 1 1 1, 那么 c n t ( x ) − 1 cnt(x) - 1 cnt(x)1, 不对答案产生影响, 否则答案 − 1 -1 1

因为每次移动指针最坏情况下是 O ( n ) O(n) O(n)次, 算法时间复杂度最坏 O ( q n ) O(qn) O(qn), 还是没办法通过

算法核心优化:因为算法瓶颈在指针会移动 O ( n q ) O(nq) O(nq)次数, 尝试使得右指针是单调的(不会向回移动), 左指针分块, 具体的来说

区间左端点按照分块的编号排序, 双关键字排序

  • 如果分块编号相同按照区间右端点从小到大排序
  • 如果分块编号不同, 块小的在前面

将所有查询分为 n \sqrt n n 块, 每一块长度是也是 n \sqrt n n , 块内部区间的右端点是递增的

对于右指针来说, 块内走的次数不会超过 n n n, 一共 n \sqrt n n 块, 最多移动 n n n \sqrt n nn

左指针分为两种情况

  • 块内最多移动 n \sqrt n n 次, 最多 q q q个询问, 算法时间复杂度最差 O ( q n ) O(q \sqrt n) O(qn )
  • 块间最多移动 2 n 2 \sqrt n 2n 次, 最多跨越 n − 1 \sqrt n - 1 n 1个块, 最差 O ( 2 n ) O(2n) O(2n)

因此左指针的最差时间复杂度是 O ( q n ) O(q \sqrt n) O(qn )

左右时间取最大值, 因此优化后的算法时间复杂度最坏情况下是 O ( q n ) O(q \sqrt n) O(qn )

假设块的大小是 a a a, 右指针的移动次数是 n 2 a \frac{n ^ 2}{a} an2, 左指针的最大移动次数是 m a ma ma, 也就是
a = n 2 m a = \sqrt {\frac{n ^ 2}{m}} a=mn2
左右指针移动的次数相当

如果 a = n a = \sqrt n a=n 比较慢, 尝试将 a a a变为 n 2 m \sqrt {\frac{n ^ 2}{m}} mn2

核心代码

    // i表示左端点, j表示右端点
    for (int k = 0, i = 1, j = 0, res = 0; k < m; ++k) {
   
        // l, r分别代表当前区间的左右端点
        auto [l, r, id] = q[k];
        while (j < r) add(w[++j], res);
        while (j > r) del(w[j--], res);
        while (i < l) del(w[i++], res);
        while (i > l) add(w[--i], res);

        ans[id] = res;
    }

示例代码

#include <bits/stdc++.h>

using namespace std;

const int N = 50010, M = 2e5 + 10, S = 1e6 + 10;

int n, m, len;
int w[N], cnt[S], ans[M];
struct Ask {
   
    int l, r, id;
} q[M];

int get(int i) {
   
    return i / len;
}

bool cmp(const Ask &a, const Ask &b) {
   
    int ba = get(a.l), bb = get(b.l);
    if (ba == bb) return a.r < b.r;
    return ba < bb;
}

void add(int x, int &ans) {
   
    if (!cnt[x]) ans++;
    cnt[x]++;
}

void del(int x, int &ans) {
   
    cnt[x]--;
    if (!cnt[x]) ans--;
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n;
    for (int i = 1; i <= n; ++i) cin >> w[i];
    cin >> m;

    // 计算块长
    len = sqrt(n * n / m + 1);

    for (int i = 0; i < m; ++i) {
   
        int l, r;
        cin >> l >> r;
        q[i] = {
   l, r, i};
    }

    sort(q, q + m, cmp);

    for (int k = 0, i = 1, j = 0, res = 0; k < m; ++k) {
   
        auto [l, r, id] = q[k];
        while (j < r) add(w[++j], res);
        while (j > r) del(w[j--], res);
        while (i < l) del(w[i++], res);
        while (i > l) add(w[--i], res);

        ans[id] = res;
    }

    for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
    return 0;
}

带修莫队

因为添加了修改操作, 因此需要加一个维度时间戳

因此询问需要多一个属性, 在第 k k k个修改之后, 第 k + 1 k + 1 k+1个修改之前, 对于当前操作的时间戳是 k k k, 因此询问的属性是三元组 ( l , r , k ) (l, r, k) (l,r,k)

指针 i , j i, j i,j的移动不会影响原序列的数值, 但是时间戳 t t t移动可能导致序列数值变化, 具体的来说

假设 t t t k k k变到了 k + 1 k + 1 k+1, 每个操作只会修改一个数

  • 假设修改的数字不在区间中, 对 c n t cnt cnt数组不产生影响
  • 假设修改的数字在区间中, 做如下操作, 首先在 c n t cnt cnt中减去原来的数, 然后再加上新的数, 操作时间复杂度是 O ( 1 ) O(1) O(1)

因此 t t t发生变化对时间复杂度不产生影响

假设时间 k k k数值是 x x x, 在 k + 1 k + 1 k+1的数值变为了 x ′ x' x, 如果需要从时间 k + 1 k + 1 k+1回到 k k k, 需要将数据恢复

具体的来说, 如果在 k + 1 k + 1 k+1时间将数值修改为 x ′ x' x, 从 k + 1 k + 1 k+1时间回到 k k k时间, swap()

算法核心

使得一个指针单调, 剩下的指针分块, 按照修改时间递增, 区间 l , r l, r l,r分块

设块的大小是 a a a, 块的数量是 n a \frac{n}{a} an

  • l l l指针, 块内移动数量是 m a ma ma, 块间移动是 2 a ⋅ n a 2a \cdot \frac{n}{a} 2aan, 移动次数 O ( m a + n ) O(ma + n) O(ma+n)
  • t t t指针, t t t递增, l l l n a \frac{n}{a} an块, r r r同理, 移动次数 O ( t ⋅ n a n a ) O(t \cdot \frac{n}{a} \frac{n}{a}) O(tanan)
  • l l l固定, r r r块内移动最差 m a ma ma, 块间移动是 O ( n ) O(n) O(n), 对于每个 l l l来说, r r r都是移动 O ( n ) O(n) O(n), 因此块间 O ( n 2 a ) O(\frac{n ^ 2}{a}) O(an2), 总的移动次数 O ( m a + n 2 a ) O(ma + \frac{n ^ 2}{a}) O(ma+an2)

块大小取 n t 3 \sqrt[3]{nt} 3nt

核心代码

bool cmp(const Ask &a, const Ask &b) {
   
    int al = get(a.l), ar = get(a.r);
    int bl = get(b.l), br = get(b.r);

    // 先按照区间的左端点所在的块编号排
    if (al != bl) return al < bl;
    // 如果左端点块相等, 按照右端点所在的块排
    if (ar != br) return ar < br;
    // 都相等说明两个区间在一个块内, 按照时间戳排
    return a.t < b.t;
}

int res = 0;
    for (int k = 1, i = 1, j = 0, t = 0; k <= qc; ++k) {
   
        int l = q[k].l, r = q[k].r, tm = q[k].t, id = q[k].id;
        while (j < r) add(w[++j], res);
        while (j > r) del(w[j--], res);
        while (i < l) del(w[i++], res);
        while (i > l) add(w[--i], res);

        //当前时间戳 < tm, 需要向上移动
        while (t < tm) {
   
            t++;
            // 如果修改位置落在了当前区间, 那么执行修改
            if (c[t].p >= i && c[t].p <= j) {
   
                del(w[c[t].p], res);
                add(c[t].c, res);
            }
            swap(w[c[t].p], c[t].c);
        }
        // 当前时间戳大于tm, 向下移动
        while (t > tm) {
   
            if (c[t].p >= i && c[t].p <= j) {
   
                del(w[c[t].p], res);
                add(c[t].c, res);
            }
            // 修改前的数值被交换到了c[t]中, 如果得到, 直接swap回来
            swap(w[c[t].p], c[t].c);
            t--;
        }
        ans[id] = res;
    }

示例代码

#include <bits/stdc++.h>

using namespace std;

const int N = 1e5 + 10, M = 1e5 + 10, S = 1e6 + 10;

int n, m, len, cc, qc;
int w[N], ans[M], cnt[S];
struct Ask {
   
    int l, r, id, t;
} q[M];
struct M {
   
    int p, c;
} c[M];

int get(int i) {
   
    return i / len;
}

bool cmp(const Ask &a, const Ask &b) {
   
    int al = get(a.l), ar = get(a.r);
    int bl = get(b.l), br = get(b.r);
    if (al != bl) return al < bl;
    if (ar != br) return ar < br;
    return a.t < b.t;
}

void add(int x, int &ans) {
   
    if (!cnt[x]) ans++;
    cnt[x]++;
}

void del(int x, int &ans) {
   
    cnt[x]--;
    if (!cnt[x]) ans--;
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n >> m;
    for (int i = 1; i <= n; ++i) cin >> w[i];

    for (int i = 0; i < m; ++i) {
   
        int a, b;
        char ch;
        cin >> ch >> a >> b;
        if (ch == 'Q') q[++qc] = {
   a, b, qc, cc};
        else c[++cc] = {
   a, b};
    }

    // 计算块的长度
    len = cbrt((double) n * max(1, cc)) + 1;

    // 如果查询区间左右端点在同一个块内, 按照时间戳顺序排序
    sort(q + 1, q + qc + 1, cmp);

    int res = 0;
    for (int k = 1, i = 1, j = 0, t = 0; k <= qc; ++k) {
   
        int l = q[k].l, r = q[k].r, tm = q[k].t, id = q[k].id;
        while (j < r) add(w[++j], res);
        while (j > r) del(w[j--], res);
        while (i < l) del(w[i++], res);
        while (i > l) add(w[--i], res);

        //当前时间戳 < tm, 需要向上移动
        while (t < tm) {
   
            t++;
            // 如果修改位置落在了当前区间, 那么执行修改
            if (c[t].p >= i && c[t].p <= j) {
   
                del(w[c[t].p], res);
                add(c[t].c, res);
            }
            swap(w[c[t].p], c[t].c);
        }
        // 当前时间戳大于tm, 向下移动
        while (t > tm) {
   
            if (c[t].p >= i && c[t].p <= j) {
   
                del(w[c[t].p], res);
                add(c[t].c, res);
            }
            // 修改前的数值被交换到了c[t]中, 如果得到, 直接swap回来
            swap(w[c[t].p], c[t].c);
            t--;
        }
        ans[id] = res;
    }

    for (int i = 1; i <= qc; ++i) cout << ans[i] << '\n';

    return 0;
}

回滚莫队

每次询问一个区间 [ l , r ] [l, r] [l,r], 很多种不同的数
, 对于每个数 x x x, 假设出现次数是 t t t, 每个数的重要度等于 t ⋅ x t \cdot x tx

希望求出区间内重要度的最大值

插入一个数求最大值直接取 m a x max max, 但是删除后很难维护最大值(可能需要开一个大根堆)

具体的算法步骤

(1) r r r单调, l l l分块进行排序

(2) 考虑块内如何计算

假设当前考虑的是块 1 1 1, 左端点一定在块 1 1 1

  • 假设右端点也在块 1 1 1, 直接暴力计算( O ( n ) O(\sqrt n) O(n ))所有询问
  • 剩下的询问, 右端点在后面

将区间分为两部分, 如果是右边部分, 因为 r r r是递增的, 因此直接插入

但是, c n t , r e s cnt, res cnt,res维护的不是整个查询区间, 而是从下一个块开始的信息, 因为左边的长度 ≤ n \le \sqrt n n , 因此左边部分暴力计算

左边暴力加之后, 再暴力回滚回去

注意数据范围很大, 需要开 l o n g    l o n g long \; long longlong, 并且需要离散化

核心代码

void solve() {
   
    for (int x = 0; x < m; ) {
   
        int y = x;
        // x, y之间的部分就是当前块的部分
        while (y < m && get(q[y].l) == get(q[x].l)) y++;
        // 当前块的尾端点(块的第一个位置 + 块长 - 1)
        int rb = get(q[x].l) * len + len - 1;

        // 暴力求查询区间右端点在块内
        while (x < y && q[x].r <= rb) {
   
            LL res = 0;
            auto [l, r, id] = q[x];
            for (int i = l; i <= r; ++i) add(w[i], res);
            ans[id] = res;
            // 每个询问的cnt数组产生的贡献不同需要将当前查询区间的影响减去
            for (int i = l; i <= r; ++i) cnt[w[i]]--;
            x++;
        }

        LL res = 0;
        int i = rb + 1, j = rb;
        while (x < y) {
   
            auto [l, r, id] = q[x];
            // 将区间右端点移动到r
            while (j < r) add(w[++j], res);
            LL backup = res;
            // 初始状态在块外部, 移动到块内的同时累计答案, 并且提前备份res, 计算好结果后恢复
            while (i > l) add(w[--i], res);
            ans[id] = res;
            // 将加入的数删除
            while (i < rb + 1) cnt[w[i++]]--;
            res = backup;
            x++;
        }
        memset(cnt, 0, sizeof cnt);
    }

    for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
}

示例代码

#include <bits/stdc++.h>

using namespace std;

typedef long long LL;
const int N = 1e5 + 10, M = 1e5 + 10;

int n, m, len;
int w[N], cnt[N];
LL ans[M];
vector<int> vec;
struct Ask {
   
    int l, r, id;
} q[M];

int get(int x) {
   
    return x / len;
}

int find(int x) {
   
    return lower_bound(vec.begin(), vec.end(), x) - vec.begin();
}

bool cmp(const Ask &a, const Ask &b) {
   
    int ba = get(a.l), bb = get(b.l);
    if (ba != bb) return ba < bb;
    return a.r < b.r;
}

void add(int x, LL &ans) {
   
    cnt[x]++;
    ans = max(ans, (LL) cnt[x] * vec[x]);
}

void solve() {
   
    for (int x = 0; x < m; ) {
   
        int y = x;
        // x, y之间的部分就是当前块的部分
        while (y < m && get(q[y].l) == get(q[x].l)) y++;
        // 当前块的尾端点(块的第一个位置 + 块长 - 1)
        int rb = get(q[x].l) * len + len - 1;

        // 暴力求查询区间右端点在块内
        while (x < y && q[x].r <= rb) {
   
            LL res = 0;
            auto [l, r, id] = q[x];
            for (int i = l; i <= r; ++i) add(w[i], res);
            ans[id] = res;
            // 每个询问的cnt数组产生的贡献不同需要将当前查询区间的影响减去
            for (int i = l; i <= r; ++i) cnt[w[i]]--;
            x++;
        }

        LL res = 0;
        int i = rb + 1, j = rb;
        while (x < y) {
   
            auto [l, r, id] = q[x];
            // 将区间右端点移动到r
            while (j < r) add(w[++j], res);
            LL backup = res;
            // 初始状态在块外部, 移动到块内的同时累计答案, 并且提前备份res, 计算好结果后恢复
            while (i > l) add(w[--i], res);
            ans[id] = res;
            // 将加入的数删除
            while (i < rb + 1) cnt[w[i++]]--;
            res = backup;
            x++;
        }
        memset(cnt, 0, sizeof cnt);
    }

    for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n >> m;
    len = sqrt(n);

    for (int i = 1; i <= n; ++i) {
   
        cin >> w[i];
        vec.push_back(w[i]);
    }

    // 原数组进行离散化
    sort(vec.begin(), vec.end());
    vec.erase(unique(vec.begin(), vec.end()), vec.end());

    for (int i = 1; i <= n; ++i) w[i] = find(w[i]);

    for (int i = 0; i < m; ++i) {
   
        int l, r;
        cin >> l >> r;
        q[i] = {
   l, r, i};
    }

    sort(q, q + m, cmp);
    solve();

    return 0;
}

树上莫队

u u u v v v上有多少个不同的

首先是暴力做法, 假设当当前询问的路径长度是 n n n, 统计不同数字出现的次数, 可以做到枚举一遍路径上的点就知道不同数字的个数(开一个, 记录每个数字出现的次数, 如果重复出现答案 − 1 -1 1), 一共 m m m个询问, 算法时间复杂度 O ( n m ) O(nm) O(nm), 会超时

尝试将树上问题变为区间问题, 也就是将树的序列取出, 这里用到的是欧拉序列

欧拉序列就是递归到当前点的时候记录一遍, 返回的时候再记录一遍

对于上图的欧拉序列是
1    2    2    3    5    5    6    6    7    7    3    4    8    8    4    1 1 \; 2\; 2\; 3\; 5\; 5\; 6\; 6\; 7\; 7\; 3\; 4\; 8\; 8\; 4\; 1 1223556677348841

观察树上路径 1 ⟶ 8 1 \longrightarrow 8 18

选取路径上 1 1 1最后出现的位置和 8 8 8最开始出现的位置, 发现是

1    2    2    3    5    5    6    6    7    7    3    4    8    1 \; 2\; 2\; 3\; 5\; 5\; 6\; 6\; 7\; 7\; 3\; 4\; 8\; 1223556677348

得到结论
欧拉路径上出现次数是 1 1 1的点是目标路径上的点, 因为访问过没返回

在欧拉序列中定义 f i f_i fi表示节点 i i i第一次出现的位置, e i e_i ei表示节点 i i i最后一次出现的位置

对于树上的任意节点 x , y x, y x,y, 假设 f x < f y f_x < f_y fx<fy

(1) 假设 l c a ( x , y ) = x lca(x, y) = x lca(x,y)=x, 那么在欧拉序列 [ f x , f y ] [f_x, f_y] [fx,fy]上出现一次的点的集合就是目标路径

(2) 假设 l c a ( x , y ) ≠ x lca(x, y) \ne x lca(x,y)=x, 那么在欧拉序列中 [ e x , f y ] [e_x, f_y] [ex,fy]上出现一次的点以及最近公共祖先 l c a ( x , y ) lca(x, y) lca(x,y), 上图中应该是 3 ′ 3' 3步的时候计算路径

这样就可以将树上询问变为区间询问


对于这样一段欧拉序列, 求的是只出现一次的数里面有多少个数是不同的

算法步骤

  • 对节点权值离散化
  • 求树的欧拉序列
  • l c a lca lca, d e p t h u depth_u depthu代表当前点的深度, f ( u , k ) f(u, k) f(u,k)表示从当前点 u u u向上走 2 k 2 ^ k 2k步的祖先, 可以通过递推计算, 先跳 2 k − 1 2 ^ {k - 1} 2k1再跳 2 k − 1 2 ^ {k - 1} 2k1步, 得到递推式 f ( u , k ) = f ( f ( u , k − 1 ) , k − 1 ) f(u, k) = f(f(u, k - 1), k - 1) f(u,k)=f(f(u,k1),k1)
    预处理这两个数组之后, 假设 d e p t h a < d e p t h b depth_a < depth_b deptha<depthb, 首先将 a a a跳到 b b b的高度, 然后同时向上跳
  • 将树的询问变为序列的询问
  • 莫队算法求区间不同数字的个数, c n t cnt cnt表示权值出现的次数, s t st st表示节点出现的次数

算法时间瓶颈在莫队, 算法时间复杂度 O ( n n ) O(n \sqrt n) O(nn )

核心代码

    int res = 0;
    for (int k = 0, i = 1, j = 0; k < m; ++k) {
   
        int l = q[k].l, r = q[k].r, id = q[k].id, p = q[k].p;
        while (j < r) add(seq[++j], res);
        while (j > r) add(seq[j--], res);
        while (i < l) add(seq[i++], res);
        while (i > l) add(seq[--i], res);
        if (p) add(p, res);
        ans[id] = res;
        if (p) add(p, res);
     }#include <bits/stdc++.h>

using namespace std;

const int N = 1e5 + 10, M = 1e5 + 10;

int n, m, len;
int w[N];
int head[N], ed[N], ne[N], idx;
int depth[N], f[N][17];
int seq[N], ptr, fst[N], lst[N];
int cnt[N], st[N], ans[M];
struct Ask {
   
    int l, r, id;
    int p;
} q[M];
vector<int> vec;
int que[N], h, t;

int find(int x) {
   
    return lower_bound(vec.begin(), vec.end(), x) - vec.begin();
}

void add_edge(int u, int v) {
   
    ed[idx] = v, ne[idx] = head[u], head[u] = idx++;
}

// 求欧拉序列
void dfs(int u, int fa) {
   
    seq[++ptr] = u;
    fst[u] = ptr;
    for (int i = head[u]; ~i; i = ne[i]) {
   
        int v = ed[i];
        if (v == fa) continue;
        dfs(v, u);
    }
    seq[++ptr] = u;
    lst[u] = ptr;
}

void bfs() {
   
    memset(depth, 0x3f, sizeof depth);
    h = 0, t = -1;
    depth[0] = 0, depth[1] = 1;
    que[++t] = 1;
    while (h <= t) {
   
        int u = que[h++];
        for (int i = head[u]; ~i; i = ne[i]) {
   
            int v = ed[i];
            if (depth[u] + 1 < depth[v]) {
   
                depth[v] = depth[u] + 1;
                f[v][0] = u;
                for (int k = 1; k <= 16; ++k) f[v][k] = f[f[v][k - 1]][k - 1];
                que[++t] = v;
            }
        }
    }
}

int lca(int a, int b) {
   
    if (depth[a] < depth[b]) swap(a, b);
    for (int k = 16; k >= 0; --k) {
   
        if (depth[f[a][k]] >= depth[b]) a = f[a][k];
    }
    if (a == b) return a;
    for (int k = 16; k >= 0; --k) {
   
        if (f[a][k] != f[b][k]) {
   
            a = f[a][k];
            b = f[b][k];
        }
    }
    return f[a][0];
}

int get(int i) {
   
    return i / len;
}

bool cmp(const Ask &a, const Ask &b) {
   
    int ba = get(a.l), bb = get(b.l);
    if (ba != bb) return ba < bb;
    return a.r < b.r;
}

void add(int x, int &ans) {
   
    st[x] ^= 1;
    // x节点出现了偶数次
    if (!st[x]) {
   
        cnt[w[x]]--;
        if (!cnt[w[x]]) ans--;
    }
    else {
   
        if (!cnt[w[x]]) ans++;
        cnt[w[x]]++;
    }
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0);

    memset(head, -1, sizeof head);
    cin >> n >> m;

    for (int i = 1; i <= n; ++i) cin >> w[i], vec.push_back(w[i]);
    sort(vec.begin(), vec.end());
    vec.erase(unique(vec.begin(), vec.end()), vec.end());
    for (int i = 1; i <= n; ++i) w[i] = find(w[i]);

    for (int i = 0; i < n - 1; ++i) {
   
        int a, b;
        cin >> a >> b;
        add_edge(a, b), add_edge(b, a);
    }

    // 处理欧拉序列
    dfs(1, -1);
    // 处理lca
    bfs();

    for (int i = 0; i < m; ++i) {
   
        int a, b;
        cin >> a >> b;
        if (fst[a] > fst[b]) swap(a, b);
        int p = lca(a, b);
        if (p == a) q[i] = {
   fst[a], fst[b], i, 0};
        else q[i] = {
   lst[a], fst[b], i, p};
    }

    len = sqrt(ptr);
    sort(q, q + m, cmp);

    int res = 0;
    for (int k = 0, i = 1, j = 0; k < m; ++k) {
   
        auto [l, r, id, p] = q[k];
        while (j < r) add(seq[++j], res);
        while (j > r) add(seq[j--], res);
        while (i < l) add(seq[i++], res);
        while (i > l) add(seq[--i], res);
        if (p) add(p, res);
        ans[id] = res;
        if (p) add(p, res);
     }

    for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
    return 0;
}

示例代码

#include <bits/stdc++.h>

using namespace std;

const int N = 1e5 + 10, M = 1e5 + 10;

int n, m, len;
int w[N];
int head[N], ed[N], ne[N], idx;
int depth[N], f[N][17];
int seq[N], ptr, fst[N], lst[N];
int cnt[N], st[N], ans[M];
struct Ask {
   
    int l, r, id;
    int p;
} q[M];
vector<int> vec;
int que[N], h, t;

int find(int x) {
   
    return lower_bound(vec.begin(), vec.end(), x) - vec.begin();
}

void add_edge(int u, int v) {
   
    ed[idx] = v, ne[idx] = head[u], head[u] = idx++;
}

// 求欧拉序列
void dfs(int u, int fa) {
   
    seq[++ptr] = u;
    fst[u] = ptr;
    for (int i = head[u]; ~i; i = ne[i]) {
   
        int v = ed[i];
        if (v == fa) continue;
        dfs(v, u);
    }
    seq[++ptr] = u;
    lst[u] = ptr;
}

void bfs() {
   
    memset(depth, 0x3f, sizeof depth);
    h = 0, t = -1;
    depth[0] = 0, depth[1] = 1;
    que[++t] = 1;
    while (h <= t) {
   
        int u = que[h++];
        for (int i = head[u]; ~i; i = ne[i]) {
   
            int v = ed[i];
            if (depth[u] + 1 < depth[v]) {
   
                depth[v] = depth[u] + 1;
                f[v][0] = u;
                for (int k = 1; k <= 16; ++k) f[v][k] = f[f[v][k - 1]][k - 1];
                que[++t] = v;
            }
        }
    }
}

int lca(int a, int b) {
   
    if (depth[a] < depth[b]) swap(a, b);
    for (int k = 16; k >= 0; --k) {
   
        if (depth[f[a][k]] >= depth[b]) a = f[a][k];
    }
    if (a == b) return a;
    for (int k = 16; k >= 0; --k) {
   
        if (f[a][k] != f[b][k]) {
   
            a = f[a][k];
            b = f[b][k];
        }
    }
    return f[a][0];
}

int get(int i) {
   
    return i / len;
}

bool cmp(const Ask &a, const Ask &b) {
   
    int ba = get(a.l), bb = get(b.l);
    if (ba != bb) return ba < bb;
    return a.r < b.r;
}

void add(int x, int &ans) {
   
    st[x] ^= 1;
    // x节点出现了偶数次
    if (!st[x]) {
   
        cnt[w[x]]--;
        if (!cnt[w[x]]) ans--;
    }
    else {
   
        if (!cnt[w[x]]) ans++;
        cnt[w[x]]++;
    }
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0);

    memset(head, -1, sizeof head);
    cin >> n >> m;

    for (int i = 1; i <= n; ++i) cin >> w[i], vec.push_back(w[i]);
    sort(vec.begin(), vec.end());
    vec.erase(unique(vec.begin(), vec.end()), vec.end());
    for (int i = 1; i <= n; ++i) w[i] = find(w[i]);

    for (int i = 0; i < n - 1; ++i) {
   
        int a, b;
        cin >> a >> b;
        add_edge(a, b), add_edge(b, a);
    }

    // 处理欧拉序列
    dfs(1, -1);
    // 处理lca
    bfs();

    for (int i = 0; i < m; ++i) {
   
        int a, b;
        cin >> a >> b;
        if (fst[a] > fst[b]) swap(a, b);
        int p = lca(a, b);
        if (p == a) q[i] = {
   fst[a], fst[b], i, 0};
        else q[i] = {
   lst[a], fst[b], i, p};
    }

    len = sqrt(ptr);
    sort(q, q + m, cmp);

    int res = 0;
    for (int k = 0, i = 1, j = 0; k < m; ++k) {
   
        int l = q[k].l, r = q[k].r, id = q[k].id, p = q[k].p;
        while (j < r) add(seq[++j], res);
        while (j > r) add(seq[j--], res);
        while (i < l) add(seq[i++], res);
        while (i > l) add(seq[--i], res);
        if (p) add(p, res);
        ans[id] = res;
        if (p) add(p, res);
     }

    for (int i = 0; i < m; ++i) cout << ans[i] << '\n';
    return 0;
}

二次离线莫队(Mo’s Algorithm, Secondary Offline)

如果直接用基础莫队每次 a d d ( x ) add(x) add(x), 需要遍历整个区间, 算法时间复杂度会退化到 O ( n 2 ) O(n ^ 2) O(n2)

算法精髓: 莫队移动指针时那些难以计算的 O ( n ) O(n) O(n)级别的修改操作,全部“打包”存起来,等到最后统一处理

(1) 第一步操作
要求 a i ⊕ a j a_i \oplus a_j aiaj k k k 1 1 1

不妨设 a i ⊕ A = Y a_i \oplus A = Y aiA=Y, 并且 Y Y Y的二进制表示下有 k k k 1 1 1, 由数学变换, A = Y ⊕ a i A = Y \oplus a_i A=Yai
也就是对于一个数 a i a_i ai, 所有的 A = Y ⊕ a i A = Y \oplus a_i A=Yai这样的数字都是与其配对的结果

代码表示为

vector<int> nums;
for (int i = 0; i < 1 << 14; i++) // 枚举所有可能的异或结果Y
    if (get_count(i) == k) 
        nums.push_back(i); // nums存下了所有二进制中包含k个1的数字

(2)第二步操作-外层套基础莫队, 对查询进行排序

(3)第三步操作-利用前缀和思想

当我们加入一个新元素 w i w_i wi 时,我们需要知道区间 [ l , r ] [l, r] [l,r] 中有多少个数和它配对

C n t ( [ l , r ] , w i ) = C n t ( [ 1 , r ] , w i ) − C n t ( [ 1 , l − 1 ] , w i ) Cnt([l, r], w_i) = Cnt([1, r], w_i) - Cnt([1, l - 1], w_i) Cnt([l,r],wi)=Cnt([1,r],wi)Cnt([1,l1],wi)

也就是说,我们可以预先算前缀中与某个数配对的数量,然后用减法得到区间结果

代码表示为

for (int i = 1; i <= n; i++) {
   
    for (auto y: nums) ++g[w[i] ^ y]; // 1. 统计:w[i]的出现,会让数值为 (w[i]^y) 的需求+1
    f[i] = g[w[i + 1]]; // 2. 记录:f[i] 表示在处理完前i个数后,有多少数能和 w[i+1] 配对
}

f i f_i fi 的含义非常关键:它表示在序列的前 i i i 个位置中,有多少个数 w j w_j wj满足 w j ⊕ w i + 1 w_j \oplus w_{i + 1} wjwi+1的结果有 k k k 1 1 1

(4) 第四步操作-莫队移动指针算简单的任务, 记录复杂的任务

右指针向右移动为例

  • 因为 f R f_R fR 算的是 [ 1 , R ] [1, R] [1,R] 的贡献,而我们实际需要的是 [ L , R ] [L, R] [L,R] 的贡献

  • 多算的部分是 [ 1 , L − 1 ] [1, L-1] [1,L1] 这一段。这段的贡献我们暂时不减,而是把它记在一个任务列表 range[L-1]里(因为难以计算)

(5)二次离线处理(Secondary Offline)

代码表示为

memset(g, 0, sizeof g);
for (int i = 1; i <= n; i++) {
   
    // 1. 更新前缀统计:把位置 i 的数 w[i] 加入统计
    for (auto y: nums) ++g[w[i] ^ y]; 
    
    // 2. 处理任务:检查有没有人需要 [1, i] 这个前缀的信息
    for (auto &rg: range[i]) {
    
        int id = rg.id, l = rg.l, r = rg.r, t = rg.t;
        for (int x = l; x <= r; x++)
            q[id].res += g[w[x]] * t; // 查询 w[x] 在当前前缀 [1,i] 中的配对数量
    }
}

这里 g [ x ] g[x] g[x] 的含义是:在当前扫描到的前缀中,有多少个数与数值 x x x 异或后结果有 k k k 1 1 1

当我们扫描到位置 i i i 时,所有挂起的任务 r a n g e [ i ] range[i] range[i](即需要前缀 [ 1 , i ] [1, i] [1,i] 信息的任务)就可以被执行了

预处理时间复杂度 O ( C 14 7 ⋅ n ) O(C_{14} ^ 7 \cdot n) O(C147n), 莫队时间复杂度 O ( n n ) O(n\sqrt n) O(nn ), 算法时间复杂度 O ( n ⋅ C 14 7 + n n ) O(n\cdot C_{14} ^ 7 + n \sqrt n) O(nC147+nn )

核心代码

    for (int t = 0, i = 1, j = 0; t < m; ++t) {
   
        int l = q[t].l, r = q[t].r;

        // j指针向右扩张, tsk[i - 1]记录未来[j + 1, r]与[1, i - 1]造成的影响, 也就是配对的数量
        if (j < r) tsk[i - 1].push_back({
   j + 1, r, t, -1});
        while (j < r) q[t].ans += f[j++];

        // j指针向左收缩, tsk[i - 1]记录未来[r + 1, j]与[1, i - 1]造成的影响, 也就是配对的数量
        if (j > r) tsk[i - 1].push_back({
   r + 1, j, t, 1});
        while (j > r) q[t].ans -= f[--j];

        // i指针向右收缩, 对于当前移除的元素w[i], 希望计算对区间[i + 1, j]的影响, w[i]造成的影响值是f[i - 1]
        if (i < l) tsk[j].push_back({
   i, l - 1, t, -1});
        while (i < l) q[t].ans += f[i - 1] + !k, i++;

        // i指针向左扩张, 相当于添加当前元素w[i - 1], 希望计算对区间[i, j]的影响, w[i - 1]造成的影响是f[i - 2]
        if (i > l) tsk[j].push_back({
   l, i - 1, t, 1});
        while (i > l) q[t].ans -= f[i - 2] + !k, i--;
    }

    memset(g, 0, sizeof g);
    for (int i = 1; i <= n; ++i) {
   
        for (auto y : vec) g[w[i] ^ y]++;
        for (auto &[l, r, id, t] : tsk[i]) {
   
            for (int x = l; x <= r; ++x) q[id].ans += g[w[x]] * t;
        }
    }

如何理解 i i i指针?

f [ x ] f[x] f[x]的定义是 w [ x + 1 ] w[x + 1] w[x+1]与前缀 [ 1 , x ] [1, x] [1,x]产生的配对数量

w [ i ] w[i] w[i]产生的影响是 f [ i − 1 ] f[i - 1] f[i1], w [ i − 1 ] w[i - 1] w[i1]产生的影响是 f [ i − 2 ] f[i - 2] f[i2]

示例代码

#include <bits/stdc++.h>

using namespace std;

typedef long long LL;
const int N = 1e5 + 10, M = 1e5 + 10;

int n, m, k, len;
int w[N];
LL ans[N];

struct Ask {
   
    int l, r, id;
    LL ans;
} q[M];

struct Tsk {
   
    int l, r, id, t;
};

vector<Tsk> tsk[N];
int f[N], g[N];

inline int get_cnt(int x) {
   
    int ans = 0;
    while (x) ans += x & 1, x >>= 1;
    return ans;
}

inline int get(int i) {
   
    return i / len;
}

bool cmp(const Ask &a, const Ask &b) {
   
    int ba = get(a.l), bb = get(b.l);
    if (ba != bb) return ba < bb;
    return a.r < b.r;
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n >> m >> k;
    for (int i = 1; i <= n; ++i) cin >> w[i];

    vector<int> vec;
    for (int i = 0; i < 1 << 14; ++i) {
   
        if (get_cnt(i) == k) vec.push_back(i);
    }

    for (int i = 1; i <= n; ++i) {
   
        for (int x : vec) g[w[i] ^ x]++;
        f[i] = g[w[i + 1]];
    }

    for (int i = 0; i < m; ++i) {
   
        int l, r;
        cin >> l >> r;
        q[i] = {
   l, r, i, 0};
    }

    len = sqrt(n);
    sort(q, q + m, cmp);

    for (int idx = 0, i = 1, j = 0; idx < m; ++idx) {
   
        int l = q[idx].l, r = q[idx].r;

        if (j < r) tsk[i - 1].push_back({
   j + 1, r, idx, -1});
        while (j < r) q[idx].ans += f[j++];

        if (j > r) tsk[i - 1].push_back({
   r + 1, j, idx, 1});
        while (j > r) q[idx].ans -= f[--j];

        if (i < l) tsk[j].push_back({
   i, l - 1, idx, -1});
        while (i < l) q[idx].ans += f[i - 1] + !k, i++;

        if (i > l) tsk[j].push_back({
   l, i - 1, idx, 1});
        while (i > l) q[idx].ans -= f[i - 2] + !k, i--;
    }

    memset(g, 0, sizeof g);
    for (int i = 1; i <= n; ++i) {
   
        for (int x : vec) g[w[i] ^ x]++;
        for (auto &[l, r, id, t] : tsk[i]) {
   
            for (int x = l; x <= r; ++x) {
   
                q[id].ans += g[w[x]] * t;
            }
        }
    }

    for (int i = 1; i < m; i++) q[i].ans += q[i - 1].ans;
    for (int i = 0; i < m; i++) ans[q[i].id] = q[i].ans;
    for (int i = 0; i < m; i++) cout << ans[i] << '\n';

    return 0;
}