这道题的可持久化 Trie 要沿父亲节点建立前缀(每个节点的版本在其父节点版本的基础上插入自身的值,即根到该节点路径上的前缀),而不是按 DFS 序建立区间前缀。
查询时,对每一位只需判断 sum[x][!t] + sum[y][!t] - 2*sum[lca][!t] > 0 即可:这样统计的是路径上除 lca 以外的节点,lca 本身的值单独用 cc[lca]^s 计算,最后两者取最大。
解法一:倍增求 LCA
#pragma GCC optimize(2)
#pragma comment(linker, “/ STACK : 1024000000, 1024000000”)
#include <bits/stdc++.h>
using namespace std;
const int maxn = 100005;  // node capacity (1e5 + 5)
int Case = 1;             // NOTE(review): not referenced anywhere in this program
int n, m;                 // node count and query count of the current test case
int cc[maxn];             // cc[i]: value stored on node i (read for i = 1..n)
int cnt;                  // number of edges stored in CC; edge ids start at 1
int head[maxn];           // head[u]: first edge id of u's adjacency list, 0 = empty
int dep[maxn];            // dep[u]: depth of u; the root has depth 1, dep[0] == 0
int p[30][maxn];          // p[j][u]: 2^j-th ancestor of u, 0 above the root

// One adjacency-list edge: target node v and the id of the next edge (0 ends
// the list). NOTE(review): val appears unused in this program.
struct node {
    int val;
    int v;
    int nex;
} CC[maxn << 1];
// Push the directed edge u -> v onto the front of u's adjacency list.
// Edge ids start at 1 so that 0 can act as the list terminator.
void add(int u, int v) {
    const int id = ++cnt;
    CC[id].nex = head[u];
    CC[id].v = v;
    head[u] = id;
}
// root[u]: persistent-trie version holding the values on the root..u path.
// It is indexed by tree node id only (0..n; root[0] stays 0 = empty trie),
// so maxn entries suffice -- the trie node pool itself lives inside Trie.
int root[maxn];
// Forget every edge so a new test case can rebuild the adjacency lists.
void init() {
    memset(head, 0, sizeof head);
    cnt = 0;
}
// Persistent binary trie over bits 16..0 of the node values.
// Node 0 is the shared empty node (no children; its counts are never written,
// so they stay 0). Each insert allocates 18 fresh nodes (one per bit level plus
// a leaf) and shares every untouched subtree with the previous version.
// NOTE(review): values are assumed to fit in 17 bits -- bits above 16 are
// ignored by insert/query; confirm against the problem limits.
struct Trie {
    // nex[node][b]: child reached by bit b (0 = none).
    int nex[maxn << 6][2];
    // Next unused node index.
    int tot;
    // sum[node][b]: how many values of this version continue through bit b
    // at this node's level.
    int sum[maxn << 6][2];

    // Allocate a node with no children and return its index.
    int newnode() {
        const int fresh = tot++;
        nex[fresh][0] = 0;
        nex[fresh][1] = 0;
        return fresh;
    }

    // Reset the node pool; node 0 remains the empty trie.
    void init() {
        nex[0][0] = 0;
        nex[0][1] = 0;
        tot = 1;
    }

    // Create a new version equal to version `id` plus the value x and return
    // the index of its root node.
    int insert(int x, int id) {
        const int fresh_root = newnode();
        int cur = fresh_root;
        int prev = id;
        for (int bit = 16; bit >= 0; --bit) {
            const int b = (x >> bit) & 1;
            // Counts: inherit both branches from the old version, +1 on the one taken.
            sum[cur][b] = sum[prev][b] + 1;
            sum[cur][b ^ 1] = sum[prev][b ^ 1];
            // Children: clone the branch taken, share the other one unchanged.
            const int child = newnode();
            nex[cur][b] = child;
            nex[cur][b ^ 1] = nex[prev][b ^ 1];
            cur = child;
            prev = nex[prev][b];
        }
        return fresh_root;
    }

    // Maximum of s ^ cc[w] over the nodes w on the tree path whose endpoint
    // versions are x and y; `lca` is the LCA *node id* (not a version).
    // sum[x] + sum[y] - 2 * sum[root[lca]] counts every path node except the
    // LCA, so the LCA's own value is compared separately at the end.
    int query(int lca, int x, int y, int s) {
        const int with_lca = cc[lca] ^ s;
        int z = root[lca];
        int best = 0;
        for (int bit = 16; bit >= 0; --bit) {
            // The bit that would make this position of the xor equal to 1.
            const int want = ((s >> bit) & 1) ^ 1;
            const int avail = sum[x][want] + sum[y][want] - 2 * sum[z][want];
            int go = want ^ 1;
            if (avail > 0) {
                best |= 1 << bit;
                go = want;
            }
            x = nex[x][go];
            y = nex[y][go];
            z = nex[z][go];
        }
        return std::max(best, with_lca);
    }
} trie;
// Walk the tree rooted at u (whose parent is fa) and, for every visited node w,
// record dep[w], its immediate ancestor p[0][w], and its persistent-trie version
// root[w] = root[parent] + cc[w] (the values on the root..w path).
// Uses an explicit stack instead of recursion: on a chain-shaped tree the
// recursive version nests n calls, and a linker-pragma stack increase is not
// portable (`#pragma comment(linker, ...)` is MSVC-specific).
void dfs(int u, int fa) {
    std::vector<std::pair<int, int> > stk;  // (node, parent) pairs still to visit
    stk.push_back(std::make_pair(u, fa));
    while (!stk.empty()) {
        const int cur = stk.back().first;
        const int par = stk.back().second;
        stk.pop_back();
        dep[cur] = dep[par] + 1;
        p[0][cur] = par;
        // root[par] is already built: a node is processed before its children are pushed.
        root[cur] = trie.insert(cc[cur], root[par]);
        for (int e = head[cur]; e; e = CC[e].nex) {
            const int v = CC[e].v;
            if (v != par) stk.push_back(std::make_pair(v, cur));
        }
    }
}
// Build the binary-lifting table from p[0]: p[j][i] = p[j-1][p[j-1][i]].
// Levels are filled while 2^j < n, which covers every depth difference
// in an n-node tree; ancestors above the root resolve to node 0.
void init_rmq() {
    int j = 1;
    while ((1 << j) < n) {
        for (int i = 1; i <= n; ++i) {
            const int half_way = p[j - 1][i];
            p[j][i] = p[j - 1][half_way];
        }
        ++j;
    }
}
int ask(int x, int y) {
if (dep[x] > dep[y]) swap(x, y);
for (int i = 16; i >= 0; --i)
if (dep[y] - dep[x] >= 1 << i) y = p[i][y];
if (x == y) return y;
for (int i = 16; i >= 0; --i)
if (p[i][x] && p[i][x] != p[i][y]) x = p[i][x], y = p[i][y];
return p[0][x];
}
// Handle one test case: read n node values and n-1 undirected edges, build the
// per-node trie versions and the LCA table, then answer m queries, each
// printing the maximum of (value on the path between two nodes) xor x.
void solve() {
    init();
    trie.init();
    for (int i = 1; i <= n; ++i) scanf("%d", &cc[i]);
    for (int e = 1; e < n; ++e) {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }
    // Levels not rebuilt by init_rmq (and ancestors above the root) must read as 0.
    memset(p, 0, sizeof(p));
    dfs(1, 0);
    init_rmq();
    for (int q = 0; q < m; ++q) {
        int from, to, x;
        scanf("%d%d%d", &from, &to, &x);
        const int lca = ask(to, from);
        printf("%d\n", trie.query(lca, root[from], root[to], x));
    }
}
int main() {
    // Build: g++ -std=c++11 -O2 1.cpp -o f && ./f < in.txt
    // ios::sync_with_stdio(false);  // unused: all I/O goes through scanf/printf
#ifndef ONLINE_JUDGE
    freopen("in.txt", "r", stdin);  // local testing only
    // freopen("out.txt","w",stdout);
#endif
    // Keep solving test cases until a "n m" header can no longer be read.
    for (;;) {
        if (scanf("%d%d", &n, &m) != 2) break;
        solve();
    }
    return 0;
}
解法二:树链剖分求 LCA(一开始忘记在每组数据前清零 son/top 等数组,结果超时了。。)
#pragma GCC optimize(2)
#pragma comment(linker, “/ STACK : 1024000000, 1024000000”)
#include <bits/stdc++.h>
using namespace std;
const int maxn = 100005;  // node capacity (1e5 + 5)
int Case = 1;             // NOTE(review): not referenced anywhere in this program
int n, m;                 // node count and query count of the current test case
int cc[maxn];             // cc[i]: value stored on node i (read for i = 1..n)
int cnt;                  // number of edges stored in CC; edge ids start at 1
int head[maxn];           // head[u]: first edge id of u's adjacency list, 0 = empty

// One adjacency-list edge: target node v and the id of the next edge (0 ends
// the list). NOTE(review): val appears unused in this program.
struct node {
    int val;
    int v;
    int nex;
} CC[maxn << 1];
// Push the directed edge u -> v onto the front of u's adjacency list.
// Edge ids start at 1 so that 0 can act as the list terminator.
void add(int u, int v) {
    const int id = ++cnt;
    CC[id].nex = head[u];
    CC[id].v = v;
    head[u] = id;
}
// Heavy-light decomposition state, indexed by node id (0 means "no node").
int siz[maxn];     // subtree size; siz[0] stays 0
int son[maxn];     // heavy child, 0 for a leaf; must be zeroed before dfs1
int top[maxn];     // head of the heavy chain containing the node, 0 = unassigned
int dep[maxn];     // depth; the root has depth 1, dep[0] == 0
int father[maxn];  // parent, 0 for the root
// root[u]: persistent-trie version holding the values on the root..u path.
// It is indexed by tree node id only (0..n; root[0] stays 0 = empty trie),
// so maxn entries suffice -- the trie node pool itself lives inside Trie.
int root[maxn];
// Reset per-test-case state: the adjacency lists, plus son[] and top[] for
// nodes 0..n+3 (dfs1 reads son[] and dfs2 tests top[] assuming they start at 0).
void init() {
    memset(head, 0, sizeof head);
    cnt = 0;
    const size_t prefix_bytes = sizeof(int) * (n + 4);
    memset(top, 0, prefix_bytes);
    memset(son, 0, prefix_bytes);
}
// Persistent binary trie over bits 16..0 of the node values.
// Node 0 is the shared empty node (no children; its counts are never written,
// so they stay 0). Each insert allocates 18 fresh nodes (one per bit level plus
// a leaf) and shares every untouched subtree with the previous version.
// NOTE(review): values are assumed to fit in 17 bits -- bits above 16 are
// ignored by insert/query; confirm against the problem limits.
struct Trie {
    // nex[node][b]: child reached by bit b (0 = none).
    int nex[maxn << 6][2];
    // Next unused node index.
    int tot;
    // sum[node][b]: how many values of this version continue through bit b
    // at this node's level.
    int sum[maxn << 6][2];

    // Allocate a node with no children and return its index.
    int newnode() {
        const int fresh = tot++;
        nex[fresh][0] = 0;
        nex[fresh][1] = 0;
        return fresh;
    }

    // Reset the node pool; node 0 remains the empty trie.
    void init() {
        nex[0][0] = 0;
        nex[0][1] = 0;
        tot = 1;
    }

    // Create a new version equal to version `id` plus the value x and return
    // the index of its root node.
    int insert(int x, int id) {
        const int fresh_root = newnode();
        int cur = fresh_root;
        int prev = id;
        for (int bit = 16; bit >= 0; --bit) {
            const int b = (x >> bit) & 1;
            // Counts: inherit both branches from the old version, +1 on the one taken.
            sum[cur][b] = sum[prev][b] + 1;
            sum[cur][b ^ 1] = sum[prev][b ^ 1];
            // Children: clone the branch taken, share the other one unchanged.
            const int child = newnode();
            nex[cur][b] = child;
            nex[cur][b ^ 1] = nex[prev][b ^ 1];
            cur = child;
            prev = nex[prev][b];
        }
        return fresh_root;
    }

    // Maximum of s ^ cc[w] over the nodes w on the tree path whose endpoint
    // versions are x and y; `lca` is the LCA *node id* (not a version).
    // sum[x] + sum[y] - 2 * sum[root[lca]] counts every path node except the
    // LCA, so the LCA's own value is compared separately at the end.
    int query(int lca, int x, int y, int s) {
        const int with_lca = cc[lca] ^ s;
        int z = root[lca];
        int best = 0;
        for (int bit = 16; bit >= 0; --bit) {
            // The bit that would make this position of the xor equal to 1.
            const int want = ((s >> bit) & 1) ^ 1;
            const int avail = sum[x][want] + sum[y][want] - 2 * sum[z][want];
            int go = want ^ 1;
            if (avail > 0) {
                best |= 1 << bit;
                go = want;
            }
            x = nex[x][go];
            y = nex[y][go];
            z = nex[z][go];
        }
        return std::max(best, with_lca);
    }
} trie;
// First HLD pass over the tree rooted at u (parent f). For every node w it sets
// father[w], dep[w], siz[w], the persistent-trie version root[w] (values on the
// root..w path) and the heavy child son[w]: the first child in adjacency order
// with the strictly largest subtree. Expects son[] zeroed beforehand.
// Iterative, so chain-shaped trees cannot overflow the call stack: a preorder
// pass fills the top-down fields, then a reverse pass accumulates sizes.
void dfs1(int u, int f) {
    std::vector<int> order;      // nodes in preorder (parents before descendants)
    std::vector<int> stk(1, u);  // nodes whose father[] is set but not yet visited
    father[u] = f;
    while (!stk.empty()) {
        const int cur = stk.back();
        stk.pop_back();
        const int par = father[cur];
        siz[cur] = 1;
        dep[cur] = dep[par] + 1;
        // root[par] is already built: a node is processed before its children are pushed.
        root[cur] = trie.insert(cc[cur], root[par]);
        order.push_back(cur);
        for (int e = head[cur]; e; e = CC[e].nex) {
            const int v = CC[e].v;
            if (v != par) {
                father[v] = cur;
                stk.push_back(v);
            }
        }
    }
    // Walking preorder backwards finalizes every child's size before its parent
    // reads it; children are scanned in adjacency order so ties resolve exactly
    // like the recursive version.
    for (int i = (int)order.size() - 1; i >= 0; --i) {
        const int cur = order[i];
        const int par = father[cur];
        for (int e = head[cur]; e; e = CC[e].nex) {
            const int v = CC[e].v;
            if (v != par) {
                siz[cur] += siz[v];
                if (siz[v] > siz[son[cur]]) son[cur] = v;
            }
        }
    }
}
// Second HLD pass: assign chain heads starting from u. The heavy child son[w]
// inherits top[w]; every other neighbour whose top is still 0 starts its own
// chain (top[v] = v). Expects top[] zeroed except a nonzero top[u] (solve()
// sets top[1] = 1 first). The parameter f is unused and kept for callers.
// Iterative, so a long heavy chain cannot overflow the call stack.
void dfs2(int u, int f) {
    (void)f;
    std::vector<int> stk(1, u);  // nodes whose top[] is assigned, children pending
    while (!stk.empty()) {
        const int cur = stk.back();
        stk.pop_back();
        if (son[cur]) {
            top[son[cur]] = top[cur];
            stk.push_back(son[cur]);
        }
        // The parent and the heavy child already have a nonzero top, so only
        // light children are picked up here.
        for (int e = head[cur]; e; e = CC[e].nex) {
            const int v = CC[e].v;
            if (!top[v]) {
                top[v] = v;
                stk.push_back(v);
            }
        }
    }
}
// Lowest common ancestor of x and y by climbing heavy chains: while the two
// nodes are on different chains, move the one whose chain head is deeper to
// the parent of that head; once on the same chain, the shallower node wins.
int ask(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        x = father[top[x]];
    }
    return dep[x] > dep[y] ? y : x;
}
// Handle one test case: read n node values and n-1 undirected edges, build the
// per-node trie versions and the heavy-light decomposition, then answer m
// queries, each printing the maximum of (value on the path between two nodes) xor x.
void solve() {
    init();
    trie.init();
    for (int i = 1; i <= n; ++i) scanf("%d", &cc[i]);
    for (int e = 1; e < n; ++e) {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }
    dfs1(1, 0);
    top[1] = 1;  // the root heads its own chain; dfs2 propagates from here
    dfs2(1, 0);
    for (int q = 0; q < m; ++q) {
        int from, to, x;
        scanf("%d%d%d", &from, &to, &x);
        const int lca = ask(to, from);
        printf("%d\n", trie.query(lca, root[from], root[to], x));
    }
}
int main() {
    // Build: g++ -std=c++11 -O2 1.cpp -o f && ./f < in.txt
    // ios::sync_with_stdio(false);  // unused: all I/O goes through scanf/printf
#ifndef ONLINE_JUDGE
    freopen("in.txt", "r", stdin);  // local testing only
    // freopen("out.txt","w",stdout);
#endif
    // Keep solving test cases until a "n m" header can no longer be read.
    for (;;) {
        if (scanf("%d%d", &n, &m) != 2) break;
        solve();
    }
    return 0;
}