题号 NC19996
名称 [HAOI2015]树上染色
来源 [HAOI2015]
每日一题三期汇总贴~

题目描述

有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并将其他的N-K个点染成白色。

将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。问收益最大值是多少。

样例

输入
5 2
1 2 3
1 5 1
2 3 1
2 4 2
输出
17
【样例解释】
将点1,2染黑就能获得最大收益。

算法

(树形背包问题 + 计算边的贡献)

这种给点染色然后询问距离的问题的一个常见套路就是计算每条边的贡献

每一条边都可以将点集分成左右两部

给边定义一个方向,枚举边下方点集(这样符合dfs的访问顺序)中黑点的个数那么剩余的黑点就都在上面的点集中了

那么就能得到条边的可能贡献:(假设下面点集中的黑点个数为t,sz表示下面点集的大小)

接着思考dp部分:

定义表示以u为根的树中统计前j个子树有个点被染成黑色的边的最大贡献是多少

状态计算:

根据最后一个统计的子树有多少个点被染成黑色来划分


我们发现每次只会用到,上一层的信息类似于01背包的优化方式(滚动数组)

我们可以优化掉一维

和01背包同样的问题如果从小到大枚举,转移的信息会被覆盖

所以l需要从大到小枚举


补充:

对初值的解释:

  1. 是以u为根的子树中没有一个节点被染色了所以u为根的子树对答案的贡献为0

  2. 我们在计算时没有考虑将节点染成黑色对答案的贡献,但这种情况对的父节点的答案是有贡献的。由于我们只考虑往下的边的贡献,而只会对往上的边有贡献所以

时间复杂度

(时间复杂度有点玄学,不太会分析)

C++ 代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
//#include <unordered_map>
#include <map>
#include <vector>
#include <queue>
#include <set>
#include <bitset>
#include <cmath>

#define P 131

#define lc u << 1
#define rc u << 1 | 1

using namespace std;
typedef long long LL;
const int N = 2010;
int h[N],ne[N * 2],e[N * 2],w[N * 2],idx;
int sz[N];
LL f[N][N];
int n,m;

void add(int a,int b,int c)
{
    e[idx] = b,w[idx] = c,ne[idx] = h[a],h[a] = idx ++;
}

void dfs(int u,int fa)
{
    f[u][0] = f[u][1] = 0;
    sz[u] = 1;
    for(int i = h[u];~i;i = ne[i])
    {
        int son = e[i];
        if(son == fa) continue;
        dfs(son,u);
        sz[u] += sz[son];
        for(int j = min(m,sz[u]);j >= 0;j --)
            for(int k = 0;k <= min(j,sz[son]);k ++)
            {
                f[u][j] = max(f[u][j],f[u][j - k] + f[son][k] + 1ll * w[i] * 
                ( k * (m - k) + (sz[son] - k) * (n - m - (sz[son] - k)) ) );
            }
    }
}

void solve()
{
    scanf("%d%d",&n,&m);
    memset(h,-1,sizeof h);
    memset(f,-0x3f,sizeof f);
    for(int i = 0;i < n - 1;i ++)
    {
        int a,b,c;
        scanf("%d%d%d",&a,&b,&c);
        add(a,b,c);
        add(b,a,c);
    }
    dfs(1,-1);
    printf("%lld\n",f[1][m]);
}

int main()
{
    /*#ifdef LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    #else
    #endif // LOCAL*/
    int T = 1;
    // init(N - 1);
    // scanf("%d",&T);
    while(T --)
    {
        // scanf("%lld%lld",&n,&m);
        solve();
        // test();
    }
    return 0;
}