前置知识
- 树链剖分
- 可持久化字典树
思路
首先按照 dfs 序(记作 )插入可持久化字典树中。对于查询点
的答案,相当于从可持久化字典树中查询区间
,由树剖可知,父亲的 dfn 区间数量级不会超过
个,所以相当于在可持久化字典树中查询
个区间的答案。
代码实现
时间复杂度 。由于每个点都要查询,树剖不会跑满,但是依旧需要注意常数。
以下代码跑了 2286 ms。(和某些 代码跑的差不多)
#include <bits/stdc++.h>
using namespace std;
#pragma GCC optimize(3,"Ofast","inline")
#define ull unsigned long long
// #define int long long
#define pii array<int, 2>
#define endl "\n"
const int N = 5e5;
int cnt, rt[N + 9], ch[N * 33 + 9][2], val[N * 33 + 9];
void insert(int o, int lst, int v) {
for (int i = 29; i >= 0; i--) {
val[o] = val[lst] + 1; // 在原版本的基础上更新
if ((v >> i & 1) == 0) {
if (!ch[o][0])
ch[o][0] = ++cnt;
ch[o][1] = ch[lst][1];
o = ch[o][0];
lst = ch[lst][0];
} else {
if (!ch[o][1])
ch[o][1] = ++cnt;
ch[o][0] = ch[lst][0];
o = ch[o][1];
lst = ch[lst][1];
}
}
val[o] = val[lst] + 1;
}
vector<int> v1,v2;
int query(int v) {
int ret = 0;
for (int i = 29; i >= 0; i--) {
int t = v >> i & 1;
int tmp = 0;
for (auto &o : v1) tmp += val[ch[o][!t]];
for (auto &o : v2) tmp -= val[ch[o][!t]];
if (tmp) {
ret += (1 << i);
t ^= 1;
}
for (auto &o : v1) o = ch[o][t];
for (auto &o : v2) o = ch[o][t];
}
return ret;
}
vector<int> tr[N + 9];
int w[N + 9], sz[N + 9], hs[N + 9], pa[N + 9];
int dfn[N + 9], top[N + 9], id[N + 9], now;
void dfs(int u, int fa) {
pa[u] = fa;
sz[u] = 1;
for (auto v : tr[u]) {
if (v == fa)
continue;
dfs(v, u);
sz[u] += sz[v];
if (sz[v] > sz[hs[u]])
hs[u] = v;
}
}
void dfs2(int u, int t) {
dfn[u] = ++now;
id[now] = u;
top[u] = t;
if (hs[u])
dfs2(hs[u], t);
for (auto v : tr[u]) {
if (v == pa[u] || v == hs[u])
continue;
dfs2(v, v);
}
}
int n;
int qry(int u) {
v1.clear();v2.clear();
vector<pii> seg;
seg.push_back({dfn[u]+1,dfn[u]+sz[u]-1});
for (int p = u; p; p = pa[top[p]]) {
seg.push_back({dfn[top[p]],dfn[p]});
}
sort(seg.begin(),seg.end());
int pre = 1;
for (auto [l,r] : seg) {
if (pre < l) {
v1.push_back(rt[l-1]), v2.push_back(rt[pre-1]);
}
pre = r+1;
}
if (pre <= n) v1.push_back(rt[n]), v2.push_back(rt[pre-1]);
return query(w[u]);
}
void solve() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> w[i];
}
for (int i = 2; i <= n; i++) {
int u, v;
cin >> u >> v;
tr[u].push_back(v);
tr[v].push_back(u);
}
dfs(1, 0);
dfs2(1, 1);
for (int i = 1; i <= n; i++) {
rt[i] = ++cnt;
insert(rt[i],rt[i-1],w[id[i]]);
}
for (int i = 1; i <= n; i++) {
cout << qry(i) << " ";
}
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int t = 1;
for (int i = 0; i < t; i++) {
solve();
}
return 0;
}