题目大意: 给定一棵树,每个点选择黑、白有对应的代价。定义一棵树的收益为所有黑白点对间路径边权最大值的和 问如何选择每个点的颜色使得收益-代价最大?n3000n\le3000

分析:因为要考虑路径中边的最大值,所以我们可以从小到大考虑每条边,建立KruskalKruskal重构树,将问题转换为树形DP,可以设dpi,jdp_{i,j}表示在ii子树内选了jj个黑点的最大答案,记i的左儿子为lsonlson,右儿子为rsonrson,blackblack表示ii子树内黑点个数,枚举左子树的黑点数black1black1,转移为


dpu,black=maxmax(0,blacksizerson)min(black,sizelsondplson,black1+dprson,blackblack1+valu(black1(sizerson(blackblack1))+(blackblack1)(sizelsonb1))dp_{u,black}=max_{max(0,black-size_{rson})}^{min(black,size_{lson}}dp_{lson,black1}+dp_{rson,black-black1}+val_u*(black1*(size_{rson}-(black-black1))+(black-black1)*(size_{lson}-b1))

KruskalKruskal重构树知识:https://blog.csdn.net/m0_61735576/article/details/124804973

参考代码:

#define LL long long
using namespace std;
const int M=3007; 
struct Edge{
	int u,v;
	LL w;
}edge[M];
vector<int> g[M<<1];
int n;
int color[M],fa[M<<1],siz[M<<1];
LL cost[M],val[M<<1],dp[M<<1][M];

bool Cmp(Edge e1,Edge e2){
	return e1.w<e2.w;
}

int findfa(int x){
	if(fa[x]==x) return x;
	return fa[x]=findfa(fa[x]); 
}


void dfs(int u){
	int v;
	if(u<=n){
		siz[u]=1;
		if(color[u]){
			dp[u][1]=0; dp[u][0]=-cost[u];
		} else {
			dp[u][0]=0; dp[u][1]=-cost[u];
		}
		return;
	}
	int lson=g[u][0];  int rson=g[u][1];
	//cout<<":"<<u<<" "<<lson<<" "<<rson<<"\n";
	dfs(lson); dfs(rson);
	if(siz[lson]>siz[rson]) swap(lson,rson);
	siz[u]=siz[lson]+siz[rson];
	for(int b=0;b<=siz[u];b++){
		dp[u][b]=-1e18;
		for(int b1=max(0,b-siz[rson]);b1<=min(b,siz[lson]);b1++){
			dp[u][b]=max(dp[u][b],dp[lson][b1]+dp[rson][b-b1]+val[u]*(b1*(siz[rson]-(b-b1))+(b-b1)*(siz[lson]-b1)));
		}
	}
	
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++)
		scanf("%d",&color[i]);
		
	for(int i=1;i<=n;i++)
		scanf("%lld",&cost[i]);
		
	for(int i=1;i<n;i++)
		scanf("%d%d%lld",&edge[i].u,&edge[i].v,&edge[i].w);
	
	sort(edge+1,edge+n,Cmp);
	int cnt=n;
	for(int i=1;i<=2*n;i++) fa[i]=i;
	for(int i=1;i<n;i++){
		int lfa=findfa(edge[i].u);
		int rfa=findfa(edge[i].v);
		if(lfa==rfa) continue;
		cnt++;
		g[cnt].push_back(lfa);
		g[cnt].push_back(rfa);
		fa[lfa]=cnt; fa[rfa]=cnt;
		val[cnt]=edge[i].w;
	}
	dfs(findfa(1));
	LL anq=0;
	for(int i=0;i<=n;i++){
		//printf("size:%d :%d\n",i,siz[i]);
		//printf("anq:%lld\n",dp[findfa(1)][i]);
		anq=max(anq,dp[findfa(1)][i]);
	}
	printf("%lld\n",anq);
}