2025牛客暑期多校第三场 H 题——数据结构碾压法

主播主播,这题正解是倍增+并查集按时间顺序维护可达连通块,这样的方法还是太吃操作了,有没有其它简单强势的方法呢?

有的,兄弟,有的,我们可以使用数据结构优化朴素DP转移通过本题。

题意

给定一棵 个结点的树。有一枚棋子初始位于 号结点。

按序给定 个两两不交的时间段,第 个时间段 中,树上会出现一个目标

每一时刻,若目标存在,棋子与目标位置不同且连通,则棋子向目标移动一步,否则棋子不动。若此时二者位置相同,则称该时刻二者重合。

你可以在任意时刻切断树上任意数量的边,每条边被切断后不会恢复。

判断棋子能否与目标重合,若能则最小化重合发生的时刻。

数据范围:

思路

任意指定一个目标结点,显然在任意时刻棋子要么向目标结点(朝叶子方向)移动,要么不移动,直到与目标结点重合。

由于要最小化时间且目标结点固定,贪心地尽可能早地向目标结点方向移动是更优的,我们总是可以通过删除无关的边使得棋子朝想要的方向移动,不朝反方向移动(过河拆桥)。

换句话说,我们可以考虑 表示棋子从根节点出发到达结点 的最早时刻 ,对每个目标结点出现在 的时刻 ,若 是否小于等于 ,则 是一个备选答案。

考虑如何转移,即已知棋子最早在 时刻到达结点 时,求最早何时可以移动到其儿子

思考发现:如果存在最早的大于 的时刻 ,使得此时刻存在某个目标出现在子树 ,则最早可以在时刻 到达结点 ,即

容易得到,初始值为 ,其中

主席树+二分

上述方法的难点在于如何快速找到大于时刻 ,且一个目标节点出现在子树 中的最早时刻。

假如:我们有结点 中所有目标出现“时间段数组”,且按出现时间排序,我们可以直接在这个数组上二分,在 时间内得到我们所需要的时间段及所需时刻

可是,如何快速得到子树 对应的“时间段数组”?

我们考虑将所有时间段按目标节点 DFS 序排序,则对任意一个子树中的所有时间段均是在此序列中的一个连续区间,使用 主席树 可以得到区间中(即子树中)第 个时间段,即发挥了上面中的“时间段数组”的作用。

上述解法时间复杂度 ,空间复杂度

提交记录,注意到出题人使用了 的毒瘤数据,喜提TLE

可持久化权值线段树+标记永久化+线段树上二分

主播主播,我们干嘛要费劲巴拉地模拟一个“时间段数组”,直接存储某个时间是否有目标节点出现不就好了吗?

对吗?不对不对。哦!对的对的,我们可以以时刻作为下标, 表示时刻 是否有目标节点出现。

由于我们要快速得到子树 对应的数组 及其区间和(表示某个时间段是否有节点出现,用于二分),我们使用可持久化权值线段树维护该数组。

还是将所有节点按 DFS 序排序,发现节点所属的某个时间段 的作用是将 ,这是一个区间加操作,使用标记永久化的技巧实现。

然后,我们就可以超快速地得到子树 对应的线段树(线段树差分)了,直接在该线段树上二分,即可在 时间内得到所需时刻

于是我们就得到了时间复杂度 ,空间复杂度 的解法,好像有戏?

提交记录,注意到出题人使用了 的毒瘤数据,且我们只有 512MB 空间,喜提MLE

AC 解法——小波矩阵+二分

什么?你说上面的两个方法都会爆空间,那可不就炸缸了?主播主播,有没有更强力的数据结构?

有的有的!接下来隆重登场的是是Wavelet Matrix(小波矩阵)

  • 区间第k小

  • 区间某个数出现的频率

  • 区间小于等于某个数的个数

以上操作都是 的时间复杂度,且常数很小,更关键的是其空间复杂度为 ,在常见的值域 内可视为线性。

爆论:小波矩阵是主席树的完美上位替代

我们将上文解法一简单修改后即得到了小波矩阵+二分的题解,时间复杂度 ,空间复杂度

提交记录,在使用快读后,极限通过,毛毛虫火箭终于取得了胜利,以下是 AC 代码:

#include <bits/stdc++.h>

using i64 = long long;
using u64 = unsigned long long;

using namespace std;

constexpr int inf = 1e9;
constexpr i64 infl = 1e18;
constexpr i64 p = 1e9 + 7;

constexpr int maxv = 1e9;

int read() {
    int sum = 0, fl = 1;
    int ch = getchar();
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') fl = -1;
    for (; isdigit(ch); ch = getchar()) sum = sum * 10 + ch - '0';
    return sum * fl;
}

struct BitRank {
    vector<u64> B;
    vector<int> cnt;

    void resize(int num) {
        B.resize(((num + 1) >> 6) + 1, 0);
        cnt.resize(B.size(), 0);
    }

    void set(int i, u64 val) { B[i >> 6] |= val << (i & 63); }

    void build() {
        for (int i = 1; i < B.size(); i++) {
            cnt[i] = cnt[i - 1] + __builtin_popcountll(B[i - 1]);
        }
    }

    int rank1(int i) {
        return cnt[i >> 6] + __builtin_popcountll(B[i >> 6] & ((1ull << (i & 63)) - 1ull));
    }

    int rank1(int i, int j) { return rank1(j) - rank1(i); }

    int rank0(int i) { return i - rank1(i); }

    int rank0(int i, int j) { return rank0(j) - rank0(i); }
};

struct WaveletMatrix {
    int ht;
    vector<int> pos;
    vector<BitRank> bk;

    explicit WaveletMatrix(vector<int> vec) {
        int sig = *max_element(vec.begin(), vec.end());
        ht = sig == 0 ? 1 : 64 - __builtin_clzll(sig);
        pos.resize(ht), bk.resize(ht);
        for (int i = 0; i < ht; ++i) {
            bk[i].resize(vec.size());
            for (int j = 0; j < vec.size(); ++j) {
                bk[i].set(j, vec[j] >> (ht - i - 1) & 1);
            }
            bk[i].build();
            auto it = stable_partition(vec.begin(), vec.end(), [&](int c) {
                return ~c >> (ht - i - 1) & 1;
            });
            pos[i] = it - vec.begin();
        }
    }

    // [l, r) 中val出现的频率
    int rank(int val, int l, int r) {
        return rank(val, r) - rank(val, l);
    }

    // [0, i) 中val出现的频率
    int rank(int val, int i) {
        int p = 0;
        for (int j = 0; j < ht; ++j) {
            if (val >> (ht - j - 1) & 1) {
                p = pos[j] + bk[j].rank1(p);
                i = pos[j] + bk[j].rank1(i);
            } else {
                p = bk[j].rank0(p);
                i = bk[j].rank0(i);
            }
        }
        return i - p;
    }

    // [l, r) 中第 k 小
    int quantile(int k, int l, int r) {
        int res = 0;
        for (int i = 0; i < ht; ++i) {
            int j = bk[i].rank0(l, r);
            if (j > k) {
                l = bk[i].rank0(l);
                r = bk[i].rank0(r);
            } else {
                l = pos[i] + bk[i].rank1(l);
                r = pos[i] + bk[i].rank1(r);
                k -= j;
                res |= 1 << (ht - i - 1);
            }
        }
        return res;
    }

    int rangefreq(int i, int j, int a, int b, int l, int r, int x) {
        if (i == j || r <= a || b <= l) return 0;
        int mid = (l + r) >> 1;
        if (a <= l && r <= b) {
            return j - i;
        }
        int left = rangefreq(bk[x].rank0(i), bk[x].rank0(j), a, b, l, mid, x + 1);
        int right = rangefreq(pos[x] + bk[x].rank1(i), pos[x] + bk[x].rank1(j), a, b, mid, r, x + 1);
        return left + right;
    }

    // [l, r) 在[a, b) 值域的数字个数
    int rangefreq(int l, int r, int a, int b) {
        return rangefreq(l, r, a, b, 0, 1 << ht, 0);
    }

    int rangemin(int i, int j, int a, int b, int l, int r, int x, int val) {
        if (i == j || r <= a || b <= l) return -1;
        if (r - l == 1) return val;
        int mid = (l + r) >> 1;
        int res = rangemin(bk[x].rank0(i), bk[x].rank0(j), a, b, l, mid, x + 1, val);
        if (res < 0)
            return rangemin(pos[x] + bk[x].rank1(i), pos[x] + bk[x].rank1(j),
                            a, b, mid, r, x + 1, val + (1 << (ht - x - 1)));
        return res;
    }

    // [l, r) 在[a, b) 值域内存在的最小值是什么,不存在返回-1
    int rangemin(int l, int r, int a, int b) {
        return rangemin(l, r, a, b, 0, 1 << ht, 0, 0);
    }
};

void solve() {
    int n = read(), k = read();
    vector<int> fa(n + 1);
    vector<vector<int> > adj(n + 1);
    for (int i = 2; i <= n; i++) {
        fa[i] = read();
        adj[fa[i]].push_back(i);
    }

    vector<tuple<int, int, int> > op(k);
    vector<vector<tuple<int, int, int> > > uop(n + 1);
    for (int i = 0; i < k; i++) {
        auto &[u, l, r] = op[i];
        u = read(), l = read(), r = read();
        uop[u].emplace_back(l, r, i);
    }

    int id = 0;
    vector<int> dfn(n + 1), rnk(n + 1), btn(n + 1);
    auto dfs1 = [&](auto &&self, int u) -> int {
        id++;
        dfn[u] = id, rnk[id] = u;
        if (adj[u].empty()) return btn[u] = u;
        for (auto v: adj[u]) btn[u] = self(self, v);
        return btn[u];
    };
    dfs1(dfs1, 1);

    vector<int> rt(n + 2), ids;
    for (int i = 1; i <= n; i++) {
        rt[i] = ids.size();
        for (auto [l, r, id]: uop[rnk[i]]) {
            ids.push_back(id);
        }
    }
    rt[n + 1] = ids.size();
    WaveletMatrix wm(ids);

    int ans = maxv + 1;
    vector<int> dp(n + 1, maxv + 1);
    dp[1] = 0;
    auto dfs2 = [&](auto &&self, int u) -> void {
        for (auto [l, r, id]: uop[u]) {
            if (r < dp[u]) continue;
            ans = min(ans, max(l, dp[u]));
            break;
        }

        for (auto v: adj[u]) {
            int cnt = rt[dfn[btn[v]] + 1] - rt[dfn[v]];
            int cl = 0, cr = cnt, idx = k;
            while (cl < cr) {
                int cm = cl + cr >> 1;
                int nidx = wm.quantile(cm, rt[dfn[v]], rt[dfn[btn[v]] + 1]);
                auto &[ou, l, r] = op[nidx];
                if (r >= dp[u] + 1) {
                    idx = nidx;
                    cr = cm;
                } else {
                    cl = cm + 1;
                }
            }

            if (idx == k) continue;
            auto &[ou, ll, rr] = op[idx];
            dp[v] = min(max(ll, dp[u] + 1), maxv + 1);
            self(self, v);
        }
    };
    dfs2(dfs2, 1);

    if (ans > maxv) cout << -1 << '\n';
    else cout << ans << '\n';
}

int main() {
    // cin.tie(nullptr), ios::sync_with_stdio(false);

    int t = 1;
    // cin >> t;
    while (t--) {
        solve();
    }

    return 0;
}