2025牛客暑期多校第三场 H 题——数据结构碾压法
主播主播,这题正解是倍增+并查集按时间顺序维护可达连通块,这样的方法还是太吃操作了,有没有其它简单强势的方法呢?
有的,兄弟,有的,我们可以使用数据结构优化朴素DP转移通过本题。
题意
给定一棵 个结点的树。有一枚棋子初始位于
号结点。
按序给定 个两两不交的时间段,第
个时间段
中,树上会出现一个目标
。
每一时刻,若目标存在,棋子与目标位置不同且连通,则棋子向目标移动一步,否则棋子不动。若此时二者位置相同,则称该时刻二者重合。
你可以在任意时刻切断树上任意数量的边,每条边被切断后不会恢复。
判断棋子能否与目标重合,若能则最小化重合发生的时刻。
数据范围:
思路
任意指定一个目标结点,显然在任意时刻棋子要么向目标结点(朝叶子方向)移动,要么不移动,直到与目标结点重合。
由于要最小化时间且目标结点固定,贪心地尽可能早地向目标结点方向移动是更优的,我们总是可以通过删除无关的边使得棋子朝想要的方向移动,不朝反方向移动(过河拆桥)。
换句话说,我们可以考虑 表示棋子从根节点出发到达结点
的最早时刻 ,对每个目标结点出现在
的时刻
,若
是否小于等于
,则
是一个备选答案。
考虑如何转移,即已知棋子最早在 时刻到达结点
时,求最早何时可以移动到其儿子
。
思考发现:如果存在最早的大于 的时刻
,使得此时刻存在某个目标出现在子树
中,则最早可以在时刻
到达结点
,即
。
容易得到,初始值为 ,其中
。
主席树+二分
上述方法的难点在于如何快速找到大于时刻 ,且一个目标节点出现在子树
中的最早时刻。
假如:我们有结点 中所有目标出现“时间段数组”,且按出现时间排序,我们可以直接在这个数组上二分,在
时间内得到我们所需要的时间段及所需时刻
。
可是,如何快速得到子树 对应的“时间段数组”?
我们考虑将所有时间段按目标节点 DFS 序排序,则对任意一个子树中的所有时间段均是在此序列中的一个连续区间,使用 主席树 可以得到区间中(即子树中)第 个时间段,即发挥了上面中的“时间段数组”的作用。
上述解法时间复杂度 ,空间复杂度
。
提交记录,注意到出题人使用了 的毒瘤数据,喜提TLE。
可持久化权值线段树+标记永久化+线段树上二分
主播主播,我们干嘛要费劲巴拉地模拟一个“时间段数组”,直接存储某个时间是否有目标节点出现不就好了吗?
对吗?不对不对。哦!对的对的,我们可以以时刻作为下标, 表示时刻
是否有目标节点出现。
由于我们要快速得到子树 对应的数组
及其区间和(表示某个时间段是否有节点出现,用于二分),我们使用可持久化权值线段树维护该数组。
还是将所有节点按 DFS 序排序,发现节点所属的某个时间段 的作用是将
加
,这是一个区间加操作,使用标记永久化的技巧实现。
然后,我们就可以超快速地得到子树 对应的线段树(线段树差分)了,直接在该线段树上二分,即可在
时间内得到所需时刻
。
于是我们就得到了时间复杂度 ,空间复杂度
的解法,好像有戏?
提交记录,注意到出题人使用了 的毒瘤数据,且我们只有 512MB 空间,喜提MLE。
AC 解法——小波矩阵+二分
什么?你说上面的两个方法都会爆空间,那可不就炸缸了?主播主播,有没有更强力的数据结构?
有的有的!接下来隆重登场的是是Wavelet Matrix(小波矩阵)!
-
区间第k小
-
区间某个数出现的频率
-
区间小于等于某个数的个数
以上操作都是 的时间复杂度,且常数很小,更关键的是其空间复杂度为
,在常见的值域
内可视为线性。
爆论:小波矩阵是主席树的完美上位替代。
我们将上文解法一简单修改后即得到了小波矩阵+二分的题解,时间复杂度 ,空间复杂度
。
提交记录,在使用快读后,极限通过,毛毛虫火箭终于取得了胜利,以下是 AC 代码:
#include <bits/stdc++.h>
using i64 = long long;
using u64 = unsigned long long;
using namespace std;
constexpr int inf = 1e9;
constexpr i64 infl = 1e18;
constexpr i64 p = 1e9 + 7;
constexpr int maxv = 1e9;
int read() {
int sum = 0, fl = 1;
int ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') fl = -1;
for (; isdigit(ch); ch = getchar()) sum = sum * 10 + ch - '0';
return sum * fl;
}
struct BitRank {
vector<u64> B;
vector<int> cnt;
void resize(int num) {
B.resize(((num + 1) >> 6) + 1, 0);
cnt.resize(B.size(), 0);
}
void set(int i, u64 val) { B[i >> 6] |= val << (i & 63); }
void build() {
for (int i = 1; i < B.size(); i++) {
cnt[i] = cnt[i - 1] + __builtin_popcountll(B[i - 1]);
}
}
int rank1(int i) {
return cnt[i >> 6] + __builtin_popcountll(B[i >> 6] & ((1ull << (i & 63)) - 1ull));
}
int rank1(int i, int j) { return rank1(j) - rank1(i); }
int rank0(int i) { return i - rank1(i); }
int rank0(int i, int j) { return rank0(j) - rank0(i); }
};
struct WaveletMatrix {
int ht;
vector<int> pos;
vector<BitRank> bk;
explicit WaveletMatrix(vector<int> vec) {
int sig = *max_element(vec.begin(), vec.end());
ht = sig == 0 ? 1 : 64 - __builtin_clzll(sig);
pos.resize(ht), bk.resize(ht);
for (int i = 0; i < ht; ++i) {
bk[i].resize(vec.size());
for (int j = 0; j < vec.size(); ++j) {
bk[i].set(j, vec[j] >> (ht - i - 1) & 1);
}
bk[i].build();
auto it = stable_partition(vec.begin(), vec.end(), [&](int c) {
return ~c >> (ht - i - 1) & 1;
});
pos[i] = it - vec.begin();
}
}
// [l, r) 中val出现的频率
int rank(int val, int l, int r) {
return rank(val, r) - rank(val, l);
}
// [0, i) 中val出现的频率
int rank(int val, int i) {
int p = 0;
for (int j = 0; j < ht; ++j) {
if (val >> (ht - j - 1) & 1) {
p = pos[j] + bk[j].rank1(p);
i = pos[j] + bk[j].rank1(i);
} else {
p = bk[j].rank0(p);
i = bk[j].rank0(i);
}
}
return i - p;
}
// [l, r) 中第 k 小
int quantile(int k, int l, int r) {
int res = 0;
for (int i = 0; i < ht; ++i) {
int j = bk[i].rank0(l, r);
if (j > k) {
l = bk[i].rank0(l);
r = bk[i].rank0(r);
} else {
l = pos[i] + bk[i].rank1(l);
r = pos[i] + bk[i].rank1(r);
k -= j;
res |= 1 << (ht - i - 1);
}
}
return res;
}
int rangefreq(int i, int j, int a, int b, int l, int r, int x) {
if (i == j || r <= a || b <= l) return 0;
int mid = (l + r) >> 1;
if (a <= l && r <= b) {
return j - i;
}
int left = rangefreq(bk[x].rank0(i), bk[x].rank0(j), a, b, l, mid, x + 1);
int right = rangefreq(pos[x] + bk[x].rank1(i), pos[x] + bk[x].rank1(j), a, b, mid, r, x + 1);
return left + right;
}
// [l, r) 在[a, b) 值域的数字个数
int rangefreq(int l, int r, int a, int b) {
return rangefreq(l, r, a, b, 0, 1 << ht, 0);
}
int rangemin(int i, int j, int a, int b, int l, int r, int x, int val) {
if (i == j || r <= a || b <= l) return -1;
if (r - l == 1) return val;
int mid = (l + r) >> 1;
int res = rangemin(bk[x].rank0(i), bk[x].rank0(j), a, b, l, mid, x + 1, val);
if (res < 0)
return rangemin(pos[x] + bk[x].rank1(i), pos[x] + bk[x].rank1(j),
a, b, mid, r, x + 1, val + (1 << (ht - x - 1)));
return res;
}
// [l, r) 在[a, b) 值域内存在的最小值是什么,不存在返回-1
int rangemin(int l, int r, int a, int b) {
return rangemin(l, r, a, b, 0, 1 << ht, 0, 0);
}
};
void solve() {
int n = read(), k = read();
vector<int> fa(n + 1);
vector<vector<int> > adj(n + 1);
for (int i = 2; i <= n; i++) {
fa[i] = read();
adj[fa[i]].push_back(i);
}
vector<tuple<int, int, int> > op(k);
vector<vector<tuple<int, int, int> > > uop(n + 1);
for (int i = 0; i < k; i++) {
auto &[u, l, r] = op[i];
u = read(), l = read(), r = read();
uop[u].emplace_back(l, r, i);
}
int id = 0;
vector<int> dfn(n + 1), rnk(n + 1), btn(n + 1);
auto dfs1 = [&](auto &&self, int u) -> int {
id++;
dfn[u] = id, rnk[id] = u;
if (adj[u].empty()) return btn[u] = u;
for (auto v: adj[u]) btn[u] = self(self, v);
return btn[u];
};
dfs1(dfs1, 1);
vector<int> rt(n + 2), ids;
for (int i = 1; i <= n; i++) {
rt[i] = ids.size();
for (auto [l, r, id]: uop[rnk[i]]) {
ids.push_back(id);
}
}
rt[n + 1] = ids.size();
WaveletMatrix wm(ids);
int ans = maxv + 1;
vector<int> dp(n + 1, maxv + 1);
dp[1] = 0;
auto dfs2 = [&](auto &&self, int u) -> void {
for (auto [l, r, id]: uop[u]) {
if (r < dp[u]) continue;
ans = min(ans, max(l, dp[u]));
break;
}
for (auto v: adj[u]) {
int cnt = rt[dfn[btn[v]] + 1] - rt[dfn[v]];
int cl = 0, cr = cnt, idx = k;
while (cl < cr) {
int cm = cl + cr >> 1;
int nidx = wm.quantile(cm, rt[dfn[v]], rt[dfn[btn[v]] + 1]);
auto &[ou, l, r] = op[nidx];
if (r >= dp[u] + 1) {
idx = nidx;
cr = cm;
} else {
cl = cm + 1;
}
}
if (idx == k) continue;
auto &[ou, ll, rr] = op[idx];
dp[v] = min(max(ll, dp[u] + 1), maxv + 1);
self(self, v);
}
};
dfs2(dfs2, 1);
if (ans > maxv) cout << -1 << '\n';
else cout << ans << '\n';
}
int main() {
// cin.tie(nullptr), ios::sync_with_stdio(false);
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}