图片说明

思路:树上求方案,dp[i][j]表示以i的子树里连通块个数为j的方案数(j<=k)这里我们有2个选择,

1.删除当前边,此时节点u与他的的子节点v所能构成的方案数就是dp[u][x]乘上dp[v][1——x]

2.不删除当前边,此时节点u与节点v构成i+j(i为u的连通块个数,j为v的连通块个数)个连通块转移方程是,这里我们需要用sz[u]记录一下u节点的子节点个数,i+j<=k && i<=min(sz[u],k) j<=min(sz[v],k),最后遍历一累加dp[1][1——k]就是答案。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <set>
#include<iostream>
#include<vector>
#include<queue>
//#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define SIS std::ios::sync_with_stdio(false)
#define space putchar(' ')
#define enter putchar('\n')
#define lson root<<1
#define rson root<<1|1
typedef pair<int,int> PII;
const int mod=998244353;
const int N=2e6+10;
const int M=2e3+10;
const int inf=0x7f7f7f7f;
const int maxx=2e5+7;

ll gcd(ll a,ll b)
{
    return b==0?a:gcd(b,a%b);
}

ll lcm(ll a,ll b)
{
    return a*(b/gcd(a,b));
}

template <class T>
void read(T &x)
{
    char c;
    bool op = 0;
    while(c = getchar(), c < '0' || c > '9')
        if(c == '-')
            op = 1;
    x = c - '0';
    while(c = getchar(), c >= '0' && c <= '9')
        x = x * 10 + c - '0';
    if(op)
        x = -x;
}
template <class T>
void write(T x)
{
    if(x < 0)
        x = -x, putchar('-');
    if(x >= 10)
        write(x / 10);
    putchar('0' + x % 10);
}
ll qsm(int a,int b,int p)
{
    ll res=1%p;
    while(b)
    {
        if(b&1) res=res*a%p;
        a=1ll*a*a%p;
        b>>=1;
    }
    return res;
}

struct node
{
    int v,nex;
}edge[N];
int cnt=0,head[2020];
ll sz[2020];
ll dp[2020][2020];
ll n,k;
void add(int u,int v)
{
    edge[++cnt].v=v;
    edge[cnt].nex=head[u];
    head[u]=cnt;
}

void dfs(int u,int fa)
{
    dp[u][1]=sz[u]=1;
    for(int i=head[u];~i;i=edge[i].nex)
    {
        int v=edge[i].v;
        if(v==fa)continue;
        dfs(v,u);
        ll sum=0;
        for(int j=1;j<=min(sz[v],k);j++)
        {
            sum=(sum+dp[v][j])%mod;
        }
        for(int j=min(k,sz[u]);j;j--)
        {
            for(int p=min(k,sz[v]);p;p--)
            {
                if(j+p<=k) dp[u][j+p]=(dp[u][j+p]+dp[u][j]*dp[v][p])%mod;//不删边

            }
            dp[u][j]=dp[u][j]*sum%mod;//删除这个边
        }

        sz[u]+=sz[v];
    }

}
int main()
{
  SIS;

  cin>>n>>k;
  memset(head,-1,sizeof head);
  for(int i=0;i<n-1;i++)
  {
      int u,v;
      cin>>u>>v;
      add(u,v);
      add(v,u);
  }
  dfs(1,0);
  ll ans=0;
  for(int i=1;i<=k;i++)
    ans=(ans+dp[1][i])%mod;
    cout<<ans<<endl;


    return 0;
}