杀树

题目链接
题目大意
给一棵树n个节点,如果树上不存在一个链长大于等于m的链,那么这棵树就死了,去掉一个结点的代价是ai,删除节点后节点的边也就没了,问把这颗树杀死的代价最少是多少?(可以把这个数砍成多个树)
n和m都是5000的范围
题解
树形dp
dp[i][j]表示i为根节点的时候,向下最长长度为j的最小代价
然后转移的话
dp[i][0]就表示删除当前这个节点,因为当前节点删除了,所以子树跟父节点没关系了。。 贡献就是每个子树上dp[v][1~m]的最小值相加。
其他的话就是一个背包?
因为长度不能超过m所以枚举之前子树的长度,再枚举当前子树的长度。
比如: 枚举的之前的子树的长度是j ,当前的是k 。要满足j + k < m
代码

#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
#include <map>
#include <cmath>
#include <set>
#include <queue>
#include <string>
#include <iostream>
#include <stack>
#include <bitset>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const ll mod = 1e9+7;
const int inf  = 0x3f3f3f3f;

struct cmp
{
   
	bool operator()(const pii & a, const pii & b)
	{
   
		return a.second > b.second;
	}
};

const int maxn = 5005;
std::vector<int> vv[maxn];
int dp[maxn][maxn];
int temp[maxn];
int val[maxn];
int n,m;
void dfs(int x,int fa)
{
   
	dp[x][0] = val[x];
	for (int i = 0; i <vv[x].size(); i ++ )
	{
   
		int v = vv[x][i];
		if(v == fa)
			continue;
		dfs(v,x);
		int maxx = inf;
		for (int j = 0; j < m; j ++ )
			maxx = min(maxx,dp[v][j]);
		dp[x][0] += maxx;
		for (int j = 1; j < m; j ++ )
		{
   
			temp[j] = dp[x][j];
			dp[x][j] = inf;
		}
		for (int j = 1; j < m; j ++ )
		{
   
			for (int k = 0; k < m - j; k ++ )
			{
   
				dp[x][max(j,k + 1)] = min(dp[x][max(j,k + 1)], temp[j] + dp[v][k]);
			}
		}
	}
}


int main()
{
   
	scanf("%d%d",&n,&m);
	for (int  i= 1; i <= n; i++ )
		scanf("%d",&val[i]);
	for (int i = 1; i < n; i ++ )
	{
   
		int x,y;
		scanf("%d%d",&x,&y);
		vv[x].push_back(y);
		vv[y].push_back(x);
	}
	dfs(1,0);
	int ans = inf;
	for (int i = 0; i < m; i ++ )
	{
   
		ans = min(ans,dp[1][i]);
	}
	printf("%d\n",ans);
}