HDU7215多校 - Weighted Beautiful Tree

题意

  • 给出一棵 nn 个节点的树,每个点有点权 wiw_i,每条边有边权。
  • 现在需要修改一些点的点权,使得 wuwedgewvw_u \leq w_{edge} \leq w_v,即任意两点之间的边权的数值要在这两个点的点权之间。
  • 修改 ii 节点点权的代价是:d×cid\times c_i,其中 dd 是修改后点权的变化量。
  • 问最小的代价总和。

思路

初步思路

  • 考虑 DP。
  • 既然要修改点权,我们肯定要让 DP 的某个状态表示跟点权有关的信息
  • dp[i][j]dp[i][j] 为将 ii 节点的点权修改为 jj ,它子树的总代价,但这样杂度太大。
  • 不过是否可以推出一些小结论,来简化这个状态?

进一步

  • 可以发现,一种较优的修改点权的策略是:
  • 对于一个点,可以尝试以下策略:1.点权不变2.将点权修改为父边的边权3.将点权修改为儿子边的边权。
  • dp[i][0/1/2]dp[i][0/1/2] 为将 ii 节点的点权执行上述 1,2,3 操作,它子树的总代价。
  • 问题:对于操作 3,这样的状态表示不出将当前节点的点权具体修改成了多少。很难转移。

最终思路

  • 我们发现,转移过程我们只关心儿子节点点权跟儿子边边权的大小关系。
    • 例如,如果儿子的点权比儿子边边权小,父节点点权就得比边权大。
    • 如果儿子的点权比儿子边边权大,父节点点权就得比边权小。
    • 那么,我先确定将父节点的点权改为多少,再根据儿子边的边权,从儿子点转移过来。
  • dp[cur][0/1]dp[cur][0/1] 为,将 curcur 节点的点权修改得 \leq 或者 \geq curcur 的父边,它子树的最小代价。
  • 我先确定将父节点的点权改为多少。根据上面的结论,尝试:
  • 1.点权不变
  • 2.将点权修改为父边的边权
  • 3.将点权修改为儿子边的边权。
  • 按照上面的策略枚举该节点的新的权值。
  • 对于一个枚举的权值 numnum,有以下转移:
    • tmp+=dp[nxt][0]tmp += dp[nxt][0],对于 num>wesonnum>w_{e_{son}}nxtnxt 节点
    • tmp+=dp[nxt][1]tmp += dp[nxt][1],对于 num<wesonnum<w_{e_{son}}nxtnxt 节点
    • tmp+=min(dp[nxt][0],dp[nxt][1])tmp += \min(dp[nxt][0],dp[nxt][1]),对于 num=wesonnum=w_{e_{son}}nxtnxt 节点
    • dp[cur][0/1]=min(dp[cur][0/1],tmp)dp[cur][0/1]=\min(dp[cur][0/1],tmp)
    • 其中,dp[cur][0/1]dp[cur][0/1] 中的 [0/1][0/1] 由枚举的权值 numnum 决定。
  • 对于一个枚举的权值 numnum,上述过程要做到 O(1)O(1)
  • 如何优化上面过程的复杂度?
  • 先将当前节点相连的所有儿子节点按照 wesonw_{e_{son}} 升序排序。
  • 预处理前缀和,分别用 pre_sum0,pre_sum1,pre_sum01pre\_sum0,pre\_sum1,pre\_sum01 数组预处理出 dp[nxt][0],dp[nxt][1],min(dp[nxt][0],dp[nxt][1])dp[nxt][0],dp[nxt][1],\min(dp[nxt][0],dp[nxt][1])的前缀和。
  • 接着,权值 numnum 也需要升序排序。
  • 我们发现,因为我们对儿子边的边权排过序了,设排序后的儿子节点为 nxt1,nxt2,nxt3,nxt_1,nxt_2,nxt_3,\dotstmp+=dp[nxt][0]tmp += dp[nxt][0]nxtnxt 是分布在一个连续的一段 nxtnxt 区间里面,dp[nxt][1]dp[nxt][1]min(dp[nxt][0],dp[nxt][1])\min(dp[nxt][0],dp[nxt][1]) 也一样。
  • 可以前缀和累加了。
  • 如果我们用 limL,limRlimL,limR 代表这三个区间的两个分断点,那么随着枚举的权值 numnum 的增大,limL,limRlimL,limR 也单调递增。

代码

#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <cstdio>
#include <iostream>
#include <vector>
#include <algorithm>
#define int long long
const int N		= 1e6+10;
const int INF	= 1e18;
using namespace std;

int dp[N][2];
vector<int> G[N],W[N];
int ai[N],ci[N];
int n;

vector<int> vec[N];
vector< pair<int,int> > vec_son[N];	//<连接儿子的边权,儿子的编号>
vector<int> pre_sum0[N],pre_sum1[N],pre_sum01[N];

void DFS(int cur,int pre,int pre_len)
{
	vec[cur].push_back(pre_len);
	vec[cur].push_back(ai[cur]);
	
	for (int i=0; i<G[cur].size(); i++)
	{
		int nxt = G[cur][i];
		int len = W[cur][i];
		if(nxt==pre)continue;
		
		vec[cur].push_back(len);
		vec_son[cur].push_back({len,nxt});
		DFS(nxt, cur, len);
	}
	
	sort(vec[cur].begin(), vec[cur].end());
	sort(vec_son[cur].begin(), vec_son[cur].end());
	
	pre_sum0[cur].resize(vec_son[cur].size()+1);
	pre_sum1[cur].resize(vec_son[cur].size()+1);
	pre_sum01[cur].resize(vec_son[cur].size()+1);
	
	for (int i=0; i<vec_son[cur].size()+1; i++)
	{
		pre_sum0[cur][i] = pre_sum1[cur][i] = pre_sum01[cur][i] = 0;
	}
	
	for (int i=0; i<vec_son[cur].size(); i++)
	{
		int nxt = vec_son[cur][i].second;
		pre_sum0[cur][i] = (i==0 ? 0:(pre_sum0[cur][i-1])) + dp[nxt][0];
		pre_sum1[cur][i] = (i==0 ? 0:(pre_sum1[cur][i-1])) + dp[nxt][1];
		pre_sum01[cur][i] = (i==0 ? 0:(pre_sum01[cur][i-1])) + min(dp[nxt][0],dp[nxt][1]);
	}
	
	int L=-1, R=-1;
	
	dp[cur][0] = dp[cur][1] = INF;
	
	for (auto num : vec[cur])
	{
		while (L+1<vec_son[cur].size() && vec_son[cur][L+1].first < num )
			L++;
		
		while (R+1<vec_son[cur].size() && vec_son[cur][R+1].first <= num )
			R++;
		
		int sum = (L<0?0:pre_sum0[cur][L]) + (R<0?0:(pre_sum01[cur][R])) - (L<0?0:(pre_sum01[cur][L]))  +  (vec_son[cur].empty()?0:(pre_sum1[cur][vec_son[cur].size()-1]-(R<0?0:pre_sum1[cur][R])));
		
		if(num>=pre_len)
			dp[cur][1] = min(dp[cur][1], sum + ci[cur]*abs(ai[cur]-num));//,printf("num=%d cur=%d sum=%d, %d \n",num,cur,sum , ci[cur]*abs(ai[cur]-num));
		if(num<=pre_len)
			dp[cur][0] = min(dp[cur][0], sum + ci[cur]*abs(ai[cur]-num));
	}
	
	//printf("dp%d = %d,%d\n",cur,dp[cur][0],dp[cur][1]);
}

void Sol()
{
	DFS(1, 0, 0);
	int ans = min(dp[1][0],dp[1][1]);
	printf("%lld\n",ans);
	
}

signed main()
{
	int tt;
	scanf("%lld",&tt);
	while (tt--)
	{
		scanf("%lld",&n);
		for (int i=1; i<=n; i++)
		{
			G[i].clear();
			W[i].clear();
			vec[i].clear();
			vec_son[i].clear();
			pre_sum0[i].clear();
			pre_sum1[i].clear();
			pre_sum01[i].clear();
			dp[i][0] = dp[i][1] = INF;
		}
		for (int i=1; i<=n; i++)
		{
			scanf("%lld",&ci[i]);
		}
		for (int i=1; i<=n; i++)
		{
			scanf("%lld",&ai[i]);
		}
		for (int i=1; i<=n-1; i++)
		{
			int u,v,w;
			scanf("%lld %lld %lld",&u,&v,&w);
			G[u].push_back(v);
			W[u].push_back(w);
			G[v].push_back(u);
			W[v].push_back(w);
		}
		Sol();
//////////////////
	}
	
	
	return 0;
}