看到异或最值,显然想到线性基。
用树上倍增的方法,维护当前点\(x\)到倍增父节点\(fa[x][i]\)这条路径上的线性基,在倍增的时候暴力合并即可。
注意这个线性基的倍增数组是没有包括最后一个点的信息的,需要特殊处理。然后就搞完了。
时间复杂度\(O(n*log_n*log_v+q*log_n*log_v)\)
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define R register
#define LL long long
const int MAXN=2e4+10;
const int MAXQ=2e5+10;
int n,q;
LL val[MAXN];
int lg[MAXN];
int head[MAXN],cnt;
struct edge { int to,next; } e[MAXN<<1];
inline void add(int x,int y) { e[++cnt]={y,head[x]}; head[x]=cnt; }
class Basic {
#define MB 63
public:
LL p[MB+1];
Basic() { memset(p,0,sizeof(p)); }
inline void clear() { memset(p,0,sizeof(p)); }
inline void ins(LL x) {
for(R int i=MB;i>=0;i--)
if(x&(1LL<<i)) {
if(!p[i]) { p[i]=x; return ;}
else x^=p[i];
}
}
inline LL ask() {
LL ans=0;
for(R int i=MB;i>=0;i--)
if((ans^p[i])>ans) ans^=p[i];
return ans;
}
};
inline Basic operator + (Basic x,Basic &y) {
for(R int i=MB;i>=0;i--)
if(y.p[i]) x.ins(y.p[i]);
return x;
}
int dep[MAXN],fa[MAXN][15];
Basic bas[MAXN][15];
inline void dfs(int x,int fx) {
dep[x]=dep[fx]+1;
fa[x][0]=fx; bas[x][0].ins(val[x]);
for(R int i=1;i<=lg[dep[x]];i++) {
fa[x][i]=fa[fa[x][i-1]][i-1];
bas[x][i]=bas[x][i]+bas[x][i-1];
bas[x][i]=bas[x][i]+bas[fa[x][i-1]][i-1];
}
for(R int i=head[x];i;i=e[i].next) {
int y=e[i].to;
if(y==fx) continue;
dfs(y,x);
}
}
Basic Ans;
inline LL ask(int x,int y) {
if(dep[x]<dep[y]) swap(x,y);
Ans.clear();
while(dep[x]>dep[y]) {
Ans=Ans+bas[x][lg[dep[x]-dep[y]]];
x=fa[x][lg[dep[x]-dep[y]]];
}
if(x==y) {
Ans.ins(val[x]);
return Ans.ask();
}
for(R int i=lg[dep[x]];i>=0;i--)
if(fa[x][i]!=fa[y][i]) {
Ans=Ans+bas[x][i];
Ans=Ans+bas[y][i];
x=fa[x][i];
y=fa[y][i];
}
Ans.ins(val[x]);
Ans.ins(val[y]);
Ans.ins(val[fa[x][0]]);
return Ans.ask();
}
inline void Init() {
scanf("%d%d",&n,&q);
for(R int i=1;i<=n;i++) scanf("%lld",&val[i]);
for(R int i=1;i<n;i++ ) {
int x,y; scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
for(R int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
dfs(1,0);
}
inline void Solve() {
while(q--) {
int x,y;
scanf("%d%d",&x,&y);
printf("%lld\n",ask(x,y));
}
}
int main() {
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
Init();
Solve();
return 0;
}