Tree

题目链接

题目大意

给一棵树,让选择一个联通的子图
满足:
度数大于k的点的数量不超过1个(小于等于1)
图是联通的。
边权和最大。
问能选出来的最大边权和是多少。

题解

树形dp
dp[x][0] 表示x的子树里的点的度数都小于等于k的边权和最大值。
dp[x][1] 表示x的子树里有一个点或0个点的度数大于k的边权的最大值。
转移方程:
dp[x][0]显然是儿子中最大的k - 1个dp[v][0]的和。
dp[x][1]怎么算?
有两种情况:
当这个度数大于k的点是x的时候,dp[x][1]是所有儿子的dp[v][0]的和。
当这个度数大于k的点不是x的时候,选k - 2个儿子的dp[v][0] 一个儿子的 dp[v][1] 就好了,但是要保证最大,可以先按dp[v][0]排个序.
假设选的都是前k - 1个儿子的时候。可以先把k - 1个儿子的dp[v][0]加起来,然后把一个dp[v][0]换成dp[v][1]。
不是前k - 1个儿子的时候,可以先选前k - 2个,然后再在后面选一个最大的dp[v][1]就好了。

为什么是k - 1个呢? 因为还有一个跟他父亲相连,当然还有不跟他父亲相连的情况。不跟他父亲相连的情况不能算到dp[x][1]里面,所以这种情况直接统计答案就好了。

md 写的时候居多bug,难受了一晚上+一早上。
记得特判0的时候 不能写return 0; 多组样例 要写continue;!!!服了,我是猪
dp的时候还得特判一下1

#include<algorithm>
#include<iostream>
#include <cstdio>
#include <string>
#include <queue>
#include <cstring>
#include <stack>
#include <set>
#include <map>
using namespace std;
typedef long long ll;
typedef pair<int,ll> pii;
typedef pair<ll,ll> pll;
typedef pair<double,double> pdd;
typedef unsigned long long ull;
typedef set<int>::iterator sit;
#define st first
#define sd second
#define mkp make_pair
#define pb push_back
void tempwj(){
   freopen("hash.in","r",stdin);freopen("hash.out","w",stdout);}
ll gcd(ll a,ll b){
   return b == 0 ? a : gcd(b,a % b);}
ll qpow(ll a,ll b,ll mod){
   a %= mod;ll ans = 1;while(b){
   if(b & 1)ans = ans * a % mod;a = a * a % mod;b >>= 1;}return ans%mod;}
struct cmp{
   bool operator()(const pii & a, const pii & b){
   return a.second < b.second;}};
int lb(int x){
   return  x & -x;}
//friend bool operator < (Node a,Node b) 重载
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 998244353;
const int maxn = 2e5+10;
const int M = 3e7+5;
ll dp[maxn][2];
pll pp[maxn];
bool cmp3(pll a,pll b)
{
   
    if(a.st != b.st)
        return a.st > b.st;
    return a.sd > b.sd;
}
int k;
ll ans = 0;
std::vector<pii> vv[maxn];
ll per[maxn];
ll nex[maxn];
void dfs(int x,int fa)
{
   
    for (int i = 0; i < vv[x].size(); i ++ )
    {
   
        int v = vv[x][i].st;
        if(v == fa)
            continue;
        dfs(v,x);
    }
    int cnt =0 ;
    for (int i =0 ; i < vv[x].size(); i ++ )
    {
   
        int v=  vv[x][i].st;
        if(v == fa)
            continue;
        cnt ++ ;
        pp[cnt].st = dp[v][0] + vv[x][i].sd;
        pp[cnt].sd = dp[v][1] + vv[x][i].sd;
    }
    sort(pp + 1, pp + 1 + cnt,cmp3);

    per[0] =0 ;
    for (int  i= 1; i <= cnt; i ++ )
    {
   
        per[i] = max(per[i - 1], pp[i].sd - pp[i].st);
    }
    nex[cnt + 1] = 0;
    for (int  i= cnt; i >= 1; i -- )
    {
   
        nex[i] = max(nex[i + 1], pp[i].sd);
    }

    for (int i = 1; i <= min(k - 1,cnt); i ++ )
    {
   
        dp[x][0] += pp[i].st;
    }
    //x点是那个大于k的
    for (int i = 1; i <= cnt; i ++ )
    {
   
        dp[x][1] += pp[i].st;
    }
    
    if(k > 1)
    {
   
        ll s = 0;
        for (int i = 1; i <= min(k - 1,cnt); i ++ )
        {
   
            s += pp[i].st;
        }
        int kk = min(k - 1,cnt);
        dp[x][1] = max(dp[x][1],s + per[kk]);
        dp[x][1] = max(dp[x][1], s - pp[kk].st + nex[kk + 1]);
    }
    ans = max(ans,dp[x][1]);
    ll sum = 0;
    int kk = min(k,cnt);
    for (int  i= 1; i <= kk; i ++ )
    {
   
        sum += pp[i].st;
    }
    ans = max(ans, sum + per[kk]);
    ans = max(ans, sum - pp[kk].st + nex[kk + 1]);
}



int main()
{
   
    int T;
    scanf("%d",&T);
    while(T -- )
    {
   
        ans = 0;
        int n;
        scanf("%d%d",&n,&k);
        for (int i = 1; i <= n; i ++ )
        {
   
            dp[i][0] = dp[i][1] = 0;
            vv[i].clear();
        }
        for (int i = 1; i < n; i ++ )
        {
   
            int x,y,v;
            scanf("%d%d%d",&x,&y,&v);
            vv[x].pb(mkp(y,v));
            vv[y].pb(mkp(x,v));
        }
        if(k == 0)
        {
   
            printf("0\n");
            continue;
        }
        dfs(1,0);
        printf("%lld\n", ans);
    }
}