Codeforces 1485E Move and Swap-DP
题意:
一棵n个节点的树,叶子节点的深度都相同,每个节点有一个权值 a i a_i ai,有红蓝两个棋子,初始在根节点1,每轮进行三步操作:
1.红棋子移动到当前所在节点的儿子上
2.蓝棋子移动到当前节点层数+1的任意一个节点
3.交换红蓝两个棋子(可选)
移动到叶子节点后终止移动
每轮移动后获得的分数为棋子所在位置的权值差的绝对值,求最大分数和。
n ≤ 2 ∗ 1 0 5 n\leq 2*10^5 n≤2∗105
Solution:
我们分层进行处理
d p i dp_i dpi表示在进行完三步操作后,红节点在 i i i点时能够获得的最大分数
转移可以分成两种,不交换和交换,先讨论不交换的情况,转移方程为
d p i = m a x j d e p j = = d e p i d p p a r e n t i + ∣ a i − a j ∣ dp_i=max_{j}^{dep_j==dep_i}dp_{parent_i}+|a_i-a_j| dpi=maxjdepj==depidpparenti+∣ai−aj∣
红棋子只能从 i i i的父亲上转移下来,蓝棋子可以随便选位置转移
交换的情况,我们需要枚举交换前红棋子的位置 j j j,转移方程为
d p i = m a x j d e p j = = d e p i d p p a r e n t j + ∣ a i − a j ∣ dp_i=max_j^{dep_j==dep_i}dp_{parent_j}+|a_i-a_j| dpi=maxjdepj==depidpparentj+∣ai−aj∣
因为是取最大值,我们可以把两个方程的绝对值去掉,变成4个转移方程,这样转移复杂度是 O ( n 2 ) O(n^2) O(n2)的,但是我们可以发现对于每层每类方程程,他们的最优解 j j j是相同的,因此我们转移之前预处理出最大位置 j j j即可,总复杂度 O ( n ) O(n) O(n)。
代码:
#include<cstdio>
#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
int T,n;
struct edge{
int to,next;
}e[400010];
int head[200010],v[200010],fa[200010],sz;
long long dp[200010];
vector<pair<int,int> >vis[200010];
void add(int x,int y)
{
//cout<<x<<" "<<y<<endl;
sz++;e[sz].to=y;e[sz].next=head[x];head[x]=sz;
}
void dfs(int x,int f,int dep)
{
//cout<<x<<" "<<f<<" "<<dep<<endl;
fa[x]=f;
//cout<<v[x]<<" "<<x<<" "<<dep<<endl;
vis[dep].push_back(make_pair(v[x],x));
for (int i=head[x];i;i=e[i].next)
{
int y=e[i].to;
if (y!=f) dfs(y,x,dep+1);
}
}
int main()
{
scanf("%d",&T);
while (T--)
{
scanf("%d",&n);
for (int x,i=2;i<=n;i++)
scanf("%d",&x),add(i,x),add(x,i);
for (int i=2;i<=n;i++) scanf("%d",&v[i]);
dfs(1,-1,0);
dp[0]=0;
long long ans=0;
for (int i=1;i<=n;i++)
{
if (vis[i].size()==0) break;
sort(vis[i].begin(),vis[i].end());
pair<int,int> maxn=vis[i][vis[i].size()-1],minn=vis[i][0];
long long max1=-1e9,max2=-1e9;
for (int j=0;j<vis[i].size();j++)
max1=max(max1,dp[fa[vis[i][j].second]]+vis[i][j].first),
max2=max(max2,dp[fa[vis[i][j].second]]-vis[i][j].first);
//cout<<maxn.first<<" "<<maxn.second<<" "<<minn.first<<" "<<minn.second<<endl;
for (int j=0;j<vis[i].size();j++)
{
int x=vis[i][j].second;
dp[x]=dp[fa[x]]+vis[i][j].first-minn.first;
dp[x]=max(dp[x],dp[fa[x]]+maxn.first-vis[i][j].first);
dp[x]=max(dp[x],max1-vis[i][j].first);
dp[x]=max(dp[x],max2+vis[i][j].first);
ans=max(dp[x],ans);
//cout<<dp[x]<<" "<<x<<endl;
}
vis[i].clear();
}
printf("%lld\n",ans);
sz=0;
for (int i=1;i<=n;i++) head[i]=0;
}
}