Codeforces 1485E Move and Swap-DP

传送门

题意:

一棵n个节点的树,叶子节点的深度都相同,每个节点有一个权值 a i a_i ai,有红蓝两个棋子,初始在根节点1,每轮进行三步操作:

1.红棋子移动到当前所在节点的儿子上

2.蓝棋子移动到当前节点层数+1的任意一个节点

3.交换红蓝两个棋子(可选)

移动到叶子节点后终止移动

每轮移动后获得的分数为棋子所在位置的权值差的绝对值,求最大分数和。

n ≤ 2 ∗ 1 0 5 n\leq 2*10^5 n2105

Solution:

我们分层进行处理

d p i dp_i dpi表示在进行完三步操作后,红节点在 i i i点时能够获得的最大分数

转移可以分成两种,不交换和交换,先讨论不交换的情况,转移方程为
d p i = m a x j d e p j = = d e p i d p p a r e n t i + ∣ a i − a j ∣ dp_i=max_{j}^{dep_j==dep_i}dp_{parent_i}+|a_i-a_j| dpi=maxjdepj==depidpparenti+aiaj
红棋子只能从 i i i的父亲上转移下来,蓝棋子可以随便选位置转移

交换的情况,我们需要枚举交换前红棋子的位置 j j j,转移方程为
d p i = m a x j d e p j = = d e p i d p p a r e n t j + ∣ a i − a j ∣ dp_i=max_j^{dep_j==dep_i}dp_{parent_j}+|a_i-a_j| dpi=maxjdepj==depidpparentj+aiaj
因为是取最大值,我们可以把两个方程的绝对值去掉,变成4个转移方程,这样转移复杂度是 O ( n 2 ) O(n^2) O(n2)的,但是我们可以发现对于每层每类方程程,他们的最优解 j j j是相同的,因此我们转移之前预处理出最大位置 j j j即可,总复杂度 O ( n ) O(n) O(n)

代码:

#include<cstdio>
#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
int T,n;
struct edge{
	int to,next;
}e[400010];
int head[200010],v[200010],fa[200010],sz;
long long dp[200010];
vector<pair<int,int> >vis[200010];
void add(int x,int y)
{
	//cout<<x<<" "<<y<<endl;
	sz++;e[sz].to=y;e[sz].next=head[x];head[x]=sz;
}
void dfs(int x,int f,int dep)
{
	//cout<<x<<" "<<f<<" "<<dep<<endl;
	fa[x]=f;
	//cout<<v[x]<<" "<<x<<" "<<dep<<endl;
	vis[dep].push_back(make_pair(v[x],x));
	for (int i=head[x];i;i=e[i].next)
	{
		int y=e[i].to;
		if (y!=f) dfs(y,x,dep+1);
	}
}
int main()
{
	scanf("%d",&T);
	while (T--)
	{
		scanf("%d",&n);
		for (int x,i=2;i<=n;i++)
			scanf("%d",&x),add(i,x),add(x,i);
		for (int i=2;i<=n;i++) scanf("%d",&v[i]); 
		dfs(1,-1,0);
		dp[0]=0;
		long long ans=0;
		for (int i=1;i<=n;i++)
		{
			if (vis[i].size()==0) break;
			sort(vis[i].begin(),vis[i].end());
			pair<int,int> maxn=vis[i][vis[i].size()-1],minn=vis[i][0];
			long long max1=-1e9,max2=-1e9;
			for (int j=0;j<vis[i].size();j++)
				max1=max(max1,dp[fa[vis[i][j].second]]+vis[i][j].first),
				max2=max(max2,dp[fa[vis[i][j].second]]-vis[i][j].first);
			//cout<<maxn.first<<" "<<maxn.second<<" "<<minn.first<<" "<<minn.second<<endl;
			for (int j=0;j<vis[i].size();j++)
			{
				int x=vis[i][j].second;
				dp[x]=dp[fa[x]]+vis[i][j].first-minn.first;
				dp[x]=max(dp[x],dp[fa[x]]+maxn.first-vis[i][j].first);
				dp[x]=max(dp[x],max1-vis[i][j].first);
				dp[x]=max(dp[x],max2+vis[i][j].first);
				ans=max(dp[x],ans);
				//cout<<dp[x]<<" "<<x<<endl;
			}
			vis[i].clear();
		}
		printf("%lld\n",ans);
		sz=0;
		for (int i=1;i<=n;i++) head[i]=0;	
	}
}