Tree
题目大意
给一棵树,让选择一个联通的子图
满足:
度数大于k的点的数量不超过1个(小于等于1)
图是联通的。
边权和最大。
问能选出来的最大边权和是多少。
题解
树形dp
dp[x][0] 表示x的子树里的点的度数都小于等于k的边权和最大值。
dp[x][1] 表示x的子树里有一个点或0个点的度数大于k的边权的最大值。
转移方程:
dp[x][0]显然是儿子中最大的k - 1个dp[v][0]的和。
dp[x][1]怎么算?
有两种情况:
当这个度数大于k的点是x的时候,dp[x][1]是所有儿子的dp[v][0]的和。
当这个度数大于k的点不是x的时候,选k - 2个儿子的dp[v][0] 一个儿子的 dp[v][1] 就好了,但是要保证最大,可以先按dp[v][0]排个序.
假设选的都是前k - 1个儿子的时候。可以先把k - 1个儿子的dp[v][0]加起来,然后把一个dp[v][0]换成dp[v][1]。
不是前k - 1个儿子的时候,可以先选前k - 2个,然后再在后面选一个最大的dp[v][1]就好了。
为什么是k - 1个呢? 因为还有一个跟他父亲相连,当然还有不跟他父亲相连的情况。不跟他父亲相连的情况不能算到dp[x][1]里面,所以这种情况直接统计答案就好了。
md 写的时候居多bug,难受了一晚上+一早上。
记得特判0的时候 不能写return 0; 多组样例 要写continue;!!!服了,我是猪
dp的时候还得特判一下1
#include<algorithm>
#include<iostream>
#include <cstdio>
#include <string>
#include <queue>
#include <cstring>
#include <stack>
#include <set>
#include <map>
using namespace std;
typedef long long ll;
typedef pair<int,ll> pii;
typedef pair<ll,ll> pll;
typedef pair<double,double> pdd;
typedef unsigned long long ull;
typedef set<int>::iterator sit;
#define st first
#define sd second
#define mkp make_pair
#define pb push_back
void tempwj(){
freopen("hash.in","r",stdin);freopen("hash.out","w",stdout);}
ll gcd(ll a,ll b){
return b == 0 ? a : gcd(b,a % b);}
ll qpow(ll a,ll b,ll mod){
a %= mod;ll ans = 1;while(b){
if(b & 1)ans = ans * a % mod;a = a * a % mod;b >>= 1;}return ans%mod;}
struct cmp{
bool operator()(const pii & a, const pii & b){
return a.second < b.second;}};
int lb(int x){
return x & -x;}
//friend bool operator < (Node a,Node b) 重载
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 998244353;
const int maxn = 2e5+10;
const int M = 3e7+5;
ll dp[maxn][2];
pll pp[maxn];
bool cmp3(pll a,pll b)
{
if(a.st != b.st)
return a.st > b.st;
return a.sd > b.sd;
}
int k;
ll ans = 0;
std::vector<pii> vv[maxn];
ll per[maxn];
ll nex[maxn];
void dfs(int x,int fa)
{
for (int i = 0; i < vv[x].size(); i ++ )
{
int v = vv[x][i].st;
if(v == fa)
continue;
dfs(v,x);
}
int cnt =0 ;
for (int i =0 ; i < vv[x].size(); i ++ )
{
int v= vv[x][i].st;
if(v == fa)
continue;
cnt ++ ;
pp[cnt].st = dp[v][0] + vv[x][i].sd;
pp[cnt].sd = dp[v][1] + vv[x][i].sd;
}
sort(pp + 1, pp + 1 + cnt,cmp3);
per[0] =0 ;
for (int i= 1; i <= cnt; i ++ )
{
per[i] = max(per[i - 1], pp[i].sd - pp[i].st);
}
nex[cnt + 1] = 0;
for (int i= cnt; i >= 1; i -- )
{
nex[i] = max(nex[i + 1], pp[i].sd);
}
for (int i = 1; i <= min(k - 1,cnt); i ++ )
{
dp[x][0] += pp[i].st;
}
//x点是那个大于k的
for (int i = 1; i <= cnt; i ++ )
{
dp[x][1] += pp[i].st;
}
if(k > 1)
{
ll s = 0;
for (int i = 1; i <= min(k - 1,cnt); i ++ )
{
s += pp[i].st;
}
int kk = min(k - 1,cnt);
dp[x][1] = max(dp[x][1],s + per[kk]);
dp[x][1] = max(dp[x][1], s - pp[kk].st + nex[kk + 1]);
}
ans = max(ans,dp[x][1]);
ll sum = 0;
int kk = min(k,cnt);
for (int i= 1; i <= kk; i ++ )
{
sum += pp[i].st;
}
ans = max(ans, sum + per[kk]);
ans = max(ans, sum - pp[kk].st + nex[kk + 1]);
}
int main()
{
int T;
scanf("%d",&T);
while(T -- )
{
ans = 0;
int n;
scanf("%d%d",&n,&k);
for (int i = 1; i <= n; i ++ )
{
dp[i][0] = dp[i][1] = 0;
vv[i].clear();
}
for (int i = 1; i < n; i ++ )
{
int x,y,v;
scanf("%d%d%d",&x,&y,&v);
vv[x].pb(mkp(y,v));
vv[y].pb(mkp(x,v));
}
if(k == 0)
{
printf("0\n");
continue;
}
dfs(1,0);
printf("%lld\n", ans);
}
}