思路:题中说明这是标准的二叉树,给我们q个枝那说明连了q+1的点,我们就可以吧边的权值用节点去存,用dp[i][j]表示以i为根节点,连了共j个节点的最大值,因为每个节点只有2个或0个子节点,那么我们枚举2个子节点可能连接的节点数,我们先dfs一次去求出每个节点的2个子节点分别是谁,并在此时给节点附上对应的权值转移方程dp[u][i]=max(dp[L][k],dp[R][i-k-1])+dp[u][1] (0<k<i)表示左儿子取k个节点,那么右儿子就取i-k-1个节点(因为还有父亲节点所以要减去1)
#include <cstdio> #include <cstring> #include <algorithm> #include <set> #include<iostream> #include<vector> #include<queue> #include<stack> #include<bits/stdc++.h> using namespace std; typedef long long ll; #define SIS std::ios::sync_with_stdio(false) #define space putchar(' ') #define enter putchar('\n') #define lson root<<1 #define rson root<<1|1 typedef pair<int,int> PII; const int mod=998244353; const int N=2e6+10; const int M=400; const int inf=0x7f7f7f7f; const int maxx=2e5+7; ll gcd(ll a,ll b) { return b==0?a:gcd(b,a%b); } ll lcm(ll a,ll b) { return a*(b/gcd(a,b)); } template <class T> void read(T &x) { char c; bool op = 0; while(c = getchar(), c < '0' || c > '9') if(c == '-') op = 1; x = c - '0'; while(c = getchar(), c >= '0' && c <= '9') x = x * 10 + c - '0'; if(op) x = -x; } template <class T> void write(T x) { if(x < 0) x = -x, putchar('-'); if(x >= 10) write(x / 10); putchar('0' + x % 10); } ll qsm(int a,int b,int p) { ll res=1%p; while(b) { if(b&1) res=res*a%p; a=1ll*a*a%p; b>>=1; } return res; } struct node { int to,nex,w; }edge[M]; struct tr { int l,r; }tre[M]; int head[M]; int tot; vector<int> G[M]; int dp[M][M]; int vis[M]; void add(int u,int v,int w) { edge[++tot].to=v; edge[tot].w=w; edge[tot].nex=head[u]; head[u]=tot; } void dfs1(int u) { vis[u]=1; for(int i=head[u];~i;i=edge[i].nex) { int v=edge[i].to; if(!vis[v]) { if(!tre[u].l) tre[u].l=v; else tre[u].r=v; dfs1(v); dp[v][1]=edge[i].w; } } } int dfs2(int u,int y) { if(dp[u][y])return dp[u][y]; if(u==0) return 0; int res=0; for(int i=0;i<y;i++) { int ans1=dfs2(tre[u].l,i); int ans2=dfs2(tre[u].r,y-i-1); res=max(ans1+ans2+dp[u][1],res); } return dp[u][y]=res; } int main() { // SIS; int n,q; memset(head,-1,sizeof head); scanf("%d%d",&n,&q); for(int i=0;i<n-1;i++) { int u,v,w; scanf("%d%d%d",&u,&v,&w); add(u,v,w); add(v,u,w); } dfs1(1); printf("%d\n",dfs2(1,q+1)); return 0; }