D、杀树

时间限制:C/C++ 1秒,其他语言2秒
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld

题目描述

牛牛学习了树。

给出一棵节点数为 n 的树,删去一个点 i 的代价为 a_i,一条链的长度定义为路径上 的个数。一棵树死了,满足不存在一条长度 的链,牛牛希望用最少代价杀死这棵树。

输入描述:

第一行给出 n,l

第二行给出 n 个整数分别代表 a_i

接下来 n-1 行,每行给出 u,v 表示有一条 u 到 v 的边。

输出描述:

输出一个整数,表示最小的代价。
示例1

输入

复制
5 2
1 2 3 4 2
1 3
2 3
3 4
4 5

输出

复制
5

备注:

对于  的数据有 

对于 的数据有

对于 的数据有

解题思路

树形背包,我是通过图片说明 保存的树码量小,符合我这种懒人,在这个基础上。
通过题目给的范围小于等于5000可以满足平方级复杂度,开一个二维数组去维护。
图片说明 表示以i为根的子树,向下传j个长度的最小值,那么通过dfs从叶子往根节点推。
首先,初始状态图片说明表示去掉了自己,还要保证子树剩余的满足题目意思。就把子树中图片说明中拿出最小的累加进这个根节点的图片说明的花费。

其余的,根据

for(int i=1; i<m;++i){  //自己满足最长小于等于i,不删除连接子节点的链
            for(int j=0; j<m; ++j){ // 孩子为根节点的子树最长链满足小于等于j
                if(i+j<m)
                    //    这个时候整个u最长就是子树剩余+连接根节点   or  指定根节点剩余的最长链  取较大者
                    //    对应花费累加求最小
                    f[u][max(i,j+1)]=min(f[u][max(i,j+1)],tmp[i]+f[it][j]);
                else
                    break;
            }
        }

https://ac.nowcoder.com/acm/contest/view-submission?submissionId=43642924

Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read(){
    int s = 0, w = 1; char ch = getchar();
    while (ch < 48 || ch > 57) { if (ch == '-') w = -1; ch = getchar(); }
    while (ch >= 48 && ch <= 57) s = (s << 1) + (s << 3) + (ch ^ 48), ch = getchar();
    return s * w;
}

const int N = 5e3+7;
const int INF = 0x3f3f3f3f;
int a[N],f[N][N],tmp[N];
// f[i][j] 表示以i为根的子树,最长链小于等于j的最小值

vector<int> e[N];
int n,m;

void dfs(int u,int fa){
    f[u][0]=a[u]; //吧自己删除
    for(auto it : e[u]){
        if(it==fa)    continue;
        dfs(it,u);
        f[u][0]+=*min_element(f[it],f[it]+m); //删除连接子节点的链,还要满足子树中符合题目条件
        for(int i=1; i<m; ++i)
            tmp[i]=f[u][i],f[u][i]=INF; //可能多次更新,下面要求最小值,控制成无穷大
        for(int i=1; i<m;++i){  //自己满足最长小于等于i,不删除连接子节点的链
            for(int j=0; j<m; ++j){ // 孩子为根节点的子树最长链满足小于等于j
                if(i+j<m)
                    f[u][max(i,j+1)]=min(f[u][max(i,j+1)],tmp[i]+f[it][j]);  
                else
                    break;
            }
        }
    }
}

int main(){
    n=read(),m=read();
    for(int i=1;i<=n;++i)    a[i]=read();
    for(int i=1;i<n;++i){
        int u=read(),v=read();
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs(1,0);
    printf("%d\n",*min_element(f[1],f[1]+m)); //min_element和max_element 返回区间最小或者最大
    return 0;
}