题目描述

“你,你认错人了。我真的,真的不是食人魔。”--蓝魔法师
给出一棵树,求有多少种删边方案,使得删后的图每个连通块大小小于等于k,两种方案不同当且仅当存在一条边在一个方案中被删除,而在另一个方案中未被删除,答案对998244353取模

输入描述:

第一行两个整数n,k, 表示点数和限制
2 <= n <= 2000, 1 <= k <= 2000
接下来n-1行,每行包括两个整数u,v,表示u,v两点之间有一条无向边
保证初始图联通且合法

输出描述:

共一行,一个整数表示方案数对998244353取模的结果

题解

为什么拖了好几天才写,dp杀我呜呜呜

代码

#include<iostream>
#include<algorithm>
#include<map>
#include<vector>
#include<set>
#include<string>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
#include<stack>
using namespace std;
#define ll long long
#define ull unsigned long long
#define pb push_back
#define pii pair<int,int>
#define all(A) A.begin(), A.end()
#define fi first
#define se second
#define MP make_pair
#define rep(i,n) for(register int i=0;i<(n);++i)
#define repi(i,a,b) for(register int i=int(a);i<=(b);++i)
#define repr(i,b,a) for(register int i=int(b);i>=(a);--i)
template<typename T>
inline T read(){
    T s=0,f=1; char ch = getchar();
    while(!isdigit(ch)) {if(ch == '-') f=-1;ch = getchar();}
    while(isdigit(ch)) {s=(s<<3)+(s<<1)+ch-48;ch = getchar();}
    return s*f;
}
#define gn() read<int>()
#define gl() read<ll>()
template<typename T>
inline void print(T x) {
    if(x<0) putchar('-'), x=-x;
    if(x>9) print(x/10);
    putchar(x%10+'0');
}
////////////////////////////////////////////////////////////////////////
const int N=2050;
const int mod=998244353;
vector<int> v[N];
ll dp[N][N];
int n,k,siz[N];
void dfs(int node,int fa){
    siz[node]=1;
    dp[node][1]=1;
    for(auto to:v[node]){
        if(to==fa)continue;
        dfs(to,node);
        ll s=0;
        for(int i=1;i<=min(siz[to],k);++i){
            s=(s+dp[to][i])%mod;
        }
        for(int i=min(k,siz[node]);i;--i){
            for(int j=min(k,siz[to]);j;--j){
                if(i+j>k)continue;
                dp[node][i+j]=(dp[node][i+j]+dp[node][i]*dp[to][j])%mod;
            }
            dp[node][i]=dp[node][i]*s%mod;
        }
        siz[node]+=siz[to];
    }
}
////////////////////////////////////////////////////////////////////////
int main(){
    n=gn(),k=gn();
    repi(i,2,n){
        int x=gn(),y=gn();
        v[x].pb(y);
        v[y].pb(x);
    }
    dfs(1,0);
    ll ans=0;
    repi(i,1,k)ans=(ans+dp[1][i])%mod;
    print(ans);
}
/**
* In every life we have some trouble
* When you worry you make it double
* Don't worry,be happy.
**/