题面:
题解:
树链剖分
维护点x到根节点的异或前缀和,更改某一点的值视为更改某棵子树的值。
询问的时候分类讨论lca即可。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<set>
#include<vector>
#define ll long long
#define llu unsigned ll
#define lc (cnt<<1)
#define rc (cnt<<1|1)
using namespace std;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const int maxn=50100;
int n,m;
int head[maxn],ver[maxn<<1],nt[maxn<<1];
int f[maxn],d[maxn],si[maxn],son[maxn],rk[maxn];
int top[maxn],id[maxn],a[maxn];
int tot=1,cnt=0;
struct tree
{
bool val[maxn];
struct node
{
int l,r;
int sum,laz;
}t[maxn<<2];
void pushup(int cnt)
{
t[cnt].sum=t[lc].sum+t[rc].sum;
}
void build(int l,int r,int cnt)
{
t[cnt].l=l,t[cnt].r=r;
t[cnt].laz=t[cnt].sum=0;
if(l==r)
{
t[cnt].sum=val[rk[l]];
return ;
}
int mid=(l+r)>>1;
build(l,mid,lc);
build(mid+1,r,rc);
pushup(cnt);
}
void pushdown(int cnt)
{
if(t[cnt].laz)
{
t[lc].sum=t[lc].r-t[lc].l+1-t[lc].sum;
t[rc].sum=t[rc].r-t[rc].l+1-t[rc].sum;
t[lc].laz^=1,t[rc].laz^=1;
t[cnt].laz=0;
}
}
void change(int l,int r,int cnt)
{
if(l<=t[cnt].l&&t[cnt].r<=r)
{
t[cnt].sum=t[cnt].r-t[cnt].l+1-t[cnt].sum;
t[cnt].laz^=1;
return ;
}
pushdown(cnt);
if(l<=t[lc].r) change(l,r,lc);
if(r>=t[rc].l) change(l,r,rc);
pushup(cnt);
}
int ask(int l,int r,int cnt)
{
if(l<=t[cnt].l&&t[cnt].r<=r)
return t[cnt].sum;
pushdown(cnt);
int ans=0;
if(l<=t[lc].r) ans+=ask(l,r,lc);
if(r>=t[rc].l) ans+=ask(l,r,rc);
return ans;
}
}tt[31];
void add(int x,int y)
{
ver[++tot]=y,nt[tot]=head[x],head[x]=tot;
}
void dfs1(int x,int fa)
{
int max_son=0;
si[x]=1;
for(int i=0;i<=30;i++)
tt[i].val[x]=((a[x]>>i)&1)^tt[i].val[fa];
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y==fa) continue;
d[y]=d[x]+1;
f[y]=x;
dfs1(y,x);
si[x]+=si[y];
if(si[y]>max_son)max_son=si[y],son[x]=y;
}
}
void dfs2(int x,int t)
{
top[x]=t;
id[x]=++cnt;
rk[cnt]=x;
if(!son[x]) return ;
dfs2(son[x],t);
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y!=son[x]&&y!=f[x])
dfs2(y,y);
}
}
int ask(int i,int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]])
swap(x,y);
ans+=tt[i].ask(id[top[x]],id[x],1);
x=f[top[x]];
}
if(id[x]>id[y]) swap(x,y);
ans+=tt[i].ask(id[x],id[y],1);
return ans;
}
int lca(int x,int y)
{
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]]) swap(x,y);
x=f[top[x]];
}
return id[x]>id[y]?y:x;
}
void dofor1(int x,int val)
{
for(int i=0;i<=30;i++)
{
int pre=(a[x]>>i)&1,now=(val>>i)&1;
if(pre^now)
tt[i].change(id[x],id[x]+si[x]-1,1);
}
a[x]=val;
}
void dofor2(int x,int y)
{
int lcc=lca(x,y);
ll ans=0;
for(int i=0;i<=30;i++)
{
int sumx1=ask(i,x,lcc),sumy1=ask(i,y,lcc);
int sumx0=d[x]-d[lcc]+1-sumx1,sumy0=d[y]-d[lcc]+1-sumy1;
ans+=(1ll*sumx1*sumx0+sumy1*sumy0)*(1ll<<i);
//注意lca在此处的分类考虑后的一致性
if((a[lcc]>>i)&1) ans+=(1ll*sumx1*sumy1+1ll*sumx0*sumy0)*(1ll<<i);
else ans+=(1ll*sumx1*sumy0+1ll*sumx0*sumy1)*(1ll<<i);
}
printf("%lld\n",ans);
}
int main(void)
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
int x,y;
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs1(1,0);
dfs2(1,1);
for(int i=0;i<=30;i++)
tt[i].build(1,cnt,1);
int id;
for(int i=1;i<=m;i++)
{
scanf("%d%d%d",&id,&x,&y);
if(id==1) dofor1(x,y);
else dofor2(x,y);
}
return 0;
}