题目描述:
shy有一颗树,树有n个结点。有k种不同颜色的染料给树染色。一个染色方案是合法的,当且仅当对于所有相同颜色的点对(x,y),x到y的路径上的所有点的颜色都要与x和y相同。请统计方案数。

输入描述:
第一行两个整数n,k代表点数和颜色数;
接下来n-1行,每行两个整数x,y表示x与y之间存在一条边;

输出描述:
输出一个整数表示方案数(mod 1e9+7)。

示例1
输入
4 3
1 2
2 3
2 4
输出
39

备注:
对于30%的数据,n≤10, k≤3;
对于100%的数据,n,k≤300。

思路:
组合数学做法:
从[1,k]枚举所有联通块(联通块满足要求),然后求出每种联通块能够颜色的方案数量。
怎么枚举所有联通块?
从边的角度来考虑,我们每次枚举 i 个联通块,就会少i - 1 条边,那么从n - 1条边删掉i - 1条边就是它的方案数量。这个方案数与顺序无关所以是C(n - 1, i - 1),然后每种方案数用k 种颜色去染 i 个联通块就是排列问题了,所以最终答案就是 <munderover> i = 1 k </munderover> C ( n 1 , i 1 ) A ( k , i ) \sum_{i=1}^{k}C(n - 1, i - 1)A(k , i) i=1kC(n1,i1)A(k,i)
时间复杂度O(n)。
代码:

#include<bits/stdc++.h>
using namespace std;

typedef long long int ll;
ll mod = 1e9 + 7;
ll qp(ll a,ll b, ll p){ll ans = 1;while(b){if(b&1){ans = (ans*a)%p;--b;}a = (a*a)%p;b >>= 1;}return ans%p;}
ll Inv(ll x)          { return qp(x,mod-2,mod);}
ll C(ll n,ll m){if (m>n) return 0;ll ans = 1;for (int i = 1; i <= m; ++i) ans=ans*Inv(i)%mod*(n-i+1)%mod;return ans%mod;}
ll A(ll n,ll m,ll mod){ll sum=1; for(int i=n;i>=n-m+1;i--) sum=(sum*i)%mod; return sum%mod;}

void solved(){
	ll n,k;cin>>n>>k;
	ll ans = 0;
	for(int i = 1; i <= k; i++){
		ans += (C(n - 1, i - 1)%mod * A(k,i,mod) % mod)%mod;
	}
	cout<<ans % mod<<endl;
}
int main(){
	solved();
	return 0;
}

DP做法:
一开始看的这个题,感觉不知道怎么搞,一开始想的是dfs暴力搞一下,但是要检查(u,v)颜色是不是相同就感觉写不出来。。。
然后看了一下题解是dp,定义
dp[i][j]:前i个节点从k种颜色中取j种颜色取染色的方案数。(一开始以为从前i个节点用j种颜色。。)
转移方程:dp[i][j] = dp[i - 1][j] + dp[i - 1][j - 1] * (k - (j - 1))
大概意思是:考虑第i个节点从k种颜色选j种的方案数会等于跟它父亲节点颜色保持一致(满足条件),或者是用新的颜色那么就是原来父亲节点用了的颜色数量 * (总颜色数量 - 父亲节点用了的数量)这样相加就是dp[i][j]的数量了。

我感觉这题好像可以用组合数学的方法做,先想一想,想出来了再更新。

#include<bits/stdc++.h>
using namespace std;
 
const long long int mod = 1000000000 + 7;
long long int  dp[10000][10000];
//dp[i][j]:前i个节点从k种颜色选择j种染树的方案数量
 
void solved(){
    long long int n,k;cin>>n>>k;
    dp[0][0] = 1;
    for(int i = 1; i <= n; i++){
        for(int j = 1; j <= k; j++){
            dp[i][j] = dp[i - 1][j] + dp[i - 1][j - 1] * (k - (j - 1));
            dp[i][j] %= mod;
        }
    }
    long long int ans = 0;
    for(int i = 1; i <= k; i++){
        ans += dp[n][i];
        ans %= mod;
    }
    cout<<ans;
}
int main(){
    solved();
    return 0;
}