solution

仔细思考一下题意,发现其实就是要求将树分成不超过K个连通块,然后给每个连通块分配一种颜色。

假设连通块个数为x,那么根据乘法原理,给每个连通块分配不同颜色的方案数就是

然后问题就只剩下了如何求将树分成不超过K个连通块的方案数。

这显然可以用树形dp来求。

用f[u][i]表示将u这课子树划分为i个连通块的方案数。然后枚举一个u的子树v,同时枚举v子树中的连通块个数j。如果将v所在的连通块与u的颜色设为一样那么就会转移到,如果设为不一样,那么就会转移到

这个dp过程的复杂度表面上是的,实际上是的。

只要在枚举连通块个数的时候,上限不超过子树大小即可。

原因我们可以考虑连边,也就是如果合并两个大小分别为s1和s2的子树,看做在这两棵子树每个点之间两两连边。显然,任意两个点之间最多会连1条边。所以总共最多会连条边。所以总复杂度就是

所以此题的数据范围其实还可以加强。

code

/*
* @Author: wxyww
* @Date: 2020-04-06 10:45:15
* @Last Modified time: 2020-04-06 10:58:24
*/
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cmath>
#include<ctime>
#include<bitset>
#include<cstring>
#include<algorithm>
#include<string>
#include<queue>
#include<vector>
using namespace std;
typedef long long ll;
const int N = 310,mod = 1e9 + 7;
ll read() {
    ll x=0,f=1;char c=getchar();
    while(c<'0'||c>'9') {
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9') {
        x=x*10+c-'0';
        c=getchar();
    }
    return x*f;
}
struct node {
    int v,nxt;
}e[N << 1];
int head[N],ejs;
void add(int u,int v) {
    e[++ejs].v = v;e[ejs].nxt = head[u];head[u] = ejs;
}
int f[N][N],siz[N],tmp[N];
void dfs(int u,int fa) {
    siz[u] = 1;
    f[u][1] = 1;
    for(int i = head[u];i;i = e[i].nxt) {
        int v = e[i].v;
        if(v == fa) continue;
        dfs(v,u);

        memset(tmp,0,sizeof(tmp));
        for(int t1 = siz[u];t1 >= 1;--t1) {
            for(int t2 = siz[v];t2 >= 1;--t2) {
                tmp[t1 + t2] += 1ll * f[u][t1] * f[v][t2] % mod;
                tmp[t1 + t2] >= mod ? tmp[t1 + t2] -= mod : 0;
                tmp[t1 + t2 - 1] += 1ll * f[u][t1] * f[v][t2] % mod;
                tmp[t1 + t2 - 1] >= mod ? tmp[t1 + t2 - 1] -= mod : 0;
            }
        }
        siz[u] += siz[v];
        for(int j = 1;j <= siz[u];++j) f[u][j] = tmp[j];
    }
}
int jc[N];
int main() {
    int n = read(),K = read();
    for(int i = 1;i < n;++i) {
        int u = read(),v = read();
        add(u,v);add(v,u);
    }
    jc[0] = 1;
    for(int i = 1;i <= K;++i) jc[i] = 1ll * jc[i - 1] * (K - i + 1) % mod;
    dfs(1,0);
    ll ans = 0;
    for(int i = 1;i <= K;++i) {
        ans += 1ll * f[1][i] * jc[i] % mod;
        ans >= mod ? ans -= mod : 0;
    }
    cout<<ans;

    return 0;
}