题目链接

题目大意

给一棵有点权的树。根节点是1号节点,让找两个点,然后保留这两个点到根节点的路径,把其他的点删了。使得图上剩余点的权值异或和最大。

怎么说呢? 没干出来,,想到了字典树求一对异或值最大的点。
然后 求一个树上点的前缀和 想着枚举一个lca然后再枚举每一个子树,,但是这个肯定是超时的,不敢写。。没写。。也不会写。。

题解:

树上启发式合并或啥啥树。。没听过我太菜了。。
这个的时间复杂度不怎么会算。。。

枚举分叉节点建立一个子树的字典树,然后在另一个子树上挨个解决问题。
然后就是怎么建一个子树上的字典树了。。 可以用树上启发式合并。
代码:

#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
const int maxn = 1e5+5;

int tree[maxn * 40][2];
int sum[maxn * 40][2];
int cnt = 0;
void add(int x,int n)
{
   
	int p = 0;
	for (int i = 30; i >= 0; i -- )
	{
   
		int t = x >>i & 1;
		if(tree[p][t] == 0)
			tree[p][t] = ++cnt;
		sum[p][t] += n;
		p = tree[p][t];
	}
}

int findx(int x)
{
   
	int ans =0 ;
	int p= 0 ;
	for (int i = 30; i >= 0; i -- )
	{
   
		int t = x >> i & 1;
		if(sum[p][!t])
		{
   
			ans += 1 << i;
			p = tree[p][!t];
		}
		else
			p = tree[p][t];
	}
	return ans;
}

std::vector<int> vv[maxn];
int a[maxn];
int son[maxn];
int num[maxn];
void dfs(int x,int fa)
{
   
	
	a[x] ^= a[fa];
	num[x] ++ ;
	for (int i =0 ;i < vv[x].size(); i ++ )
	{
   
		int v = vv[x][i];
		if(v == fa)
			continue;
		dfs(v,x);
		num[x] += num[v];
		if(num[v] > num[son[x]])
			son[x] = v;
	}
}
int ans =0 ;
int flag = 0;
int k = 0;
std::vector<int> vp;
void solve(int x,int fa)
{
   
	ans = max(ans,findx(a[x] ^ a[k]));
	// printf("%d %d %d\n",x,k,ans);
	for (int i = 0; i < vv[x].size(); i ++ )
	{
   
		int v = vv[x][i];
		if(v == fa || v == flag)
			continue;
		solve(v,x);
		vp.push_back(v);
	}
	if(fa == k)
	{
   
		for (int i =0 ; i < vp.size(); i ++ )
		{
   
			add(a[vp[i]],1);
		}
		vp.clear();
	}
}
void del(int x,int fa)
{
   
	add(a[x],-1);
	for (int i =0 ; i< vv[x].size(); i ++ )
	{
   
		int v = vv[x][i];
		if(v == fa)
			continue;
		del(v,x);
	}
}
void dfs2(int x,int fa,int f)
{
   
	for (int i = 0; i < vv[x].size(); i ++ )
	{
   
		int v = vv[x][i];
		if(v == fa || v == son[x])
			continue;
		dfs2(v,x,1);
	}
	if(son[x])
		dfs2(son[x],x,0);
	flag = son[x];
	k = x;
	add(a[x],1);
	solve(x,fa);
	if(f)
	{
   
		del(x,fa);
	}
}



int main()
{
   
	int n;
	scanf("%d",&n);
	for (int i = 1; i <= n; i ++ )
	{
   
		scanf("%d",&a[i]);
	}
	for (int i = 1; i < n; i ++ )
	{
   
		int x,y;
		scanf("%d%d",&x,&y);
		vv[x].push_back(y);
		vv[y].push_back(x);
	}
	dfs(1,0);
	dfs2(1,0,1);
	printf("%d\n",ans);
}