题目链接

题面:

题意:
给定一棵树,边有边权,找到一个联通块符合以下条件:
(1)至多有一个节点的度数大于k。
(2)边权和最大。
输出最大的边权和。

题解:
考虑 d p dp dp
我们设 d p [ x ] [ 3 ] dp[x][3] dp[x][3]

(1) d p [ x ] [ 0 ] dp[x][0] dp[x][0] 表示以 x x x 为根的子树内,所选的连通块节点的儿子个数都小于等于 k 1 k-1 k1 的最大值。那我我们对 x x x 的儿子节点的 d p [ y i ] [ 0 ] + e d g e ( x , y i ) dp[y_i][0]+edge(x,y_i) dp[yi][0]+edge(x,yi) 从大到小排序后, d p [ x ] [ 0 ] = i = 1 k 1 ( d p [ y i ] [ 0 ] + e d g e ( x , y i ) ) dp[x][0]=\sum_{i=1}^{k-1}(dp[y_i][0]+edge(x,y_i)) dp[x][0]=i=1k1(dp[yi][0]+edge(x,yi))

(2) d p [ x ] [ 1 ] dp[x][1] dp[x][1] 表示以 x x x 为根的子树内,所选的联通块节点至多有一个节点的儿子节点的个数 大于等于 k k k,其余的都小于 k k k 的最大值。

这个儿子节点数大于等于 k k k 的节点,可以是当前节点 x x x,那么 d p [ x ] [ 1 ] = i = 1 c n t s o n ( d p [ y i ] [ 0 ] + e d g e ( x , y i ) ) dp[x][1]=\sum_{i=1}^{cntson}(dp[y_i][0]+edge(x,y_i)) dp[x][1]=i=1cntson(dp[yi][0]+edge(x,yi))

也可以是 x x x 的儿子孙子节点。 d p [ x ] [ 1 ] dp[x][1] dp[x][1] 应该由 k 2 k-2 k2 d p [ y i ] [ 0 ] dp[y_i][0] dp[yi][0] 和 一个 d p [ y i ] [ 1 ] dp[y_i][1] dp[yi][1] 转移而来,其中某个 y i y_i yi 只能贡献 d p [ y i ] [ 0 ] , d p [ y i ] [ 1 ] dp[y_i][0],dp[y_i][1] dp[yi][0],dp[yi][1] 中的一个。枚举选哪个 d p [ y i ] [ 1 ] dp[y_i][1] dp[yi][1],然后选剩下的前 k 2 k-2 k2 大的 d p [ y i ] [ 0 ] + e g d e ( x , y i ) dp[y_i][0]+egde(x,y_i) dp[yi][0]+egde(x,yi)即可。在这种情况下, x x x 节点只能有 k 1 k-1 k1 个儿子。

(3) d p [ x ] [ 2 ] dp[x][2] dp[x][2]表示以 x x x 为根的子树内,除了 x x x 以外所选的连通块节点至多有一个节点的儿子节点的个数大于等于 k k k,其余的都小于 k k k 且当前节点 x x x 选择 k k k 个儿子的最大值。

这样 d p [ x ] [ 2 ] dp[x][2] dp[x][2] 表示 x x x k k k 个儿子,但是 x x x 的子树内所选的连通块有至多一个节点的儿子数大于等于 k k k,也就是说其子树内所选的连通块中,是有一个节点的度大于 k k k 的,那么 x x x 节点的度不能再大于 k k k,即 d p [ x ] [ 2 ] dp[x][2] dp[x][2] 已经算是独立答案,不再往 x x x的父亲节点上传。在这种情况下, x x x 节点有 k k k 个儿子。

d p [ x ] [ 2 ] dp[x][2] dp[x][2] 应该由 k 1 k-1 k1 d p [ y i ] [ 0 ] dp[y_i][0] dp[yi][0] 和 一个 d p [ y i ] [ 1 ] dp[y_i][1] dp[yi][1] 转移而来,其中某个 y i y_i yi 只能贡献 d p [ y i ] [ 0 ] , d p [ y i ] [ 1 ] dp[y_i][0],dp[y_i][1] dp[yi][0],dp[yi][1] 中的一个。

注意: k = 0 k=0 k=0 时, a n s = 0 ans=0 ans=0

代码:

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
#include<list>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
//#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
using namespace std;

const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=998244353;
const double eps=1e-1;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=400100;
const int maxp=1100;
const int maxm=300100;
const int up=1000;

int head[maxn],ver[maxn],edge[maxn],nt[maxn],tot=1;
ll dp[maxn][3],maxx=0;
int n,k;

void add(int x,int y,int z)
{
    ver[++tot]=y,edge[tot]=z;
    nt[tot]=head[x],head[x]=tot;
}

void init(int n)
{
    for(int i=1;i<=n;i++)
        head[i]=dp[i][0]=dp[i][1]=dp[i][2]=0;
    maxx=0;
    tot=1;
}

bool cmp(const pair<int,ll> &a,const pair<int,ll> &b)
{
    return dp[a.first][0]+a.second>dp[b.first][0]+b.second;
}

void dfs(int x,int fa)
{
    if(k==0) return ;
    vector<pair<int,ll> >p;
    p.clear();
    p.pb(pr(0,0));

    for(int i=head[x];i;i=nt[i])
    {
        int y=ver[i],z=edge[i];
        if(y==fa) continue;
        dfs(y,x);
        p.pb(pr(y,z));
    }

    int cnt=p.size()-1;
    sort(p.begin()+1,p.end(),cmp);

    //dp[x][0]
    for(int i=1;i<=k-1&&i<=cnt;i++)
        dp[x][0]+=dp[p[i].first][0]+p[i].second;

    //dp[x][1]
    //x点的儿子大于等于k
    for(int i=1;i<=cnt;i++)
        dp[x][1]+=dp[p[i].first][0]+p[i].second;
    //dp[x][1]
    //dp[y][1]在前k-1个。
    for(int i=1;i<=k-1&&i<=cnt;i++)
        dp[x][1]=max(dp[x][1],dp[x][0]+(dp[p[i].first][1]-dp[p[i].first][0]));
    //dp[x][1]
    //dp[y][1]在k---cnt取
    if(k>1) //这里x节点只能有k-1个儿子,如果k==1的话会因为k-1没有贡献而导致x的节点变成1个儿子
        for(int i=k;i<=cnt;i++)
            dp[x][1]=max(dp[x][1],dp[x][0]-(dp[p[k-1].first][0]+p[k-1].second)+dp[p[i].first][1]+p[i].second);

    //dp[x][2]
    //dp[y][1]在前k个取
    ll sumk=dp[x][0];
    if(cnt>=k) sumk+=dp[p[k].first][0]+p[k].second;
    for(int i=1;i<=k&&i<=cnt;i++)
        dp[x][2]=max(dp[x][2],sumk+(dp[p[i].first][1]-dp[p[i].first][0]));
    //dp[x][2]
    //dp[y][1]在k+1---cnt取
    for(int i=k+1;i<=cnt;i++)
        dp[x][2]=max(dp[x][2],dp[x][0]+dp[p[i].first][1]+p[i].second);

    maxx=max(max(dp[x][1],dp[x][2]),maxx);
}

int main(void)
{
    int tt;
    scanf("%d",&tt);
    while(tt--)
    {
        scanf("%d%d",&n,&k);
        init(n);

        int x,y,z;
        for(int i=1;i<n;i++)
        {
            scanf("%d%d%d",&x,&y,&z);
            add(x,y,z);
            add(y,x,z);
        }

        dfs(1,0);

        printf("%lld\n",maxx);
    }
    return 0;
}