题目描述
We have n empty boxes, so let’s recolor those boxes with m colors.
The boxes are put in a line. It is not allowed to color any adjacent boxes with the same color. Boxes i and i+1 are said to be adjacent for every i,1≤i≤n.
And we also want the total number of different colors of the n boxes being exactly k.
Two ways are considered different if and only if there is at least one box being colored with different colors.
输入描述:
The first line of the input contains integer T(1≤T≤100) -the number of the test cases
For each case: there will be one line, which contains three integers n,m,k(1≤n,m≤109 1≤k≤106, k≤n,m).
题解
n个空盒子,m种颜色,然后让你选择k种颜色,染完这n个盒子,并且所用的颜色种类数刚好为k个,且相邻盒子的颜色不相同,问你有多少种方案。如果求最多k种颜色染完所有盒子的方案数的话很容易想到是图片说明
如果求刚好用了k种颜色,因此我们考虑容斥原理。
刚好用了k种颜色的方案数就是小于等于k的-小于等于k-1的
对于当前选择的颜色数小于等于i的情况为
图片说明
当然这里面会有重复,所以要用容斥定理去重
去重后选择的颜色数小于等于i的情况为
图片说明
图片说明
就是减去偶数个集合交集加上奇数集合交集
答案就是
图片说明
式子前边有个C(m,k),m有1e9这么大,大组合数显然是行不通的,但是由于k不大,利用公式化简后可以计算
代码

#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define lowbit(x) x&(-x)

typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll, ll> pll;

const int N = 1e6+5;
const ll mod = 1e9+7;
const int INF = 0x3f3f3f3f;
const double eps =1e-9;
const double PI=acos(-1.0);
const int dir[8][2]={-1,0,1,0,0,-1,0,1,1,1,1,-1,-1,1,-1,-1};

ll qpow(ll x,ll y){
    ll ans=1,t=x;
    while(y>0){
        if(y&1)ans*=t,ans%=mod;
        t*=t,t%=mod;
        y>>=1;
    }
    return ans%mod;
}
ll jc[N],ny[N];
void init(){
    jc[0]=1;ny[0]=1;
    for(int i=1;i<=N-1;i++){
        jc[i]=jc[i-1]*i%mod;
        ny[i]=ny[i-1]*qpow(i,mod-2)%mod;
    }
}
ll C(ll a,ll b){return jc[a]*ny[b]%mod*ny[a-b]%mod;}
void solve(){
    ll n,m,k;
    cin>>n>>m>>k;
    ll ans = 1;
    ans=ans*k%mod*qpow(k-1,n-1)%mod;
    for(int i=1;i<k;i++){
        ans+=((i&1?-1:1)*C(k,k-i)*(k-i)%mod*qpow(k-i-1,n-1)+mod)%mod;
    }
    for(ll i=m-k+1;i<=m;i++)ans=ans*i%mod;
    ans=ans*ny[k]%mod;
    cout<<(ans+mod)%mod;
}

int main(){
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    init();
    int t;cin>>t;
    while(t--)solve(),cout<<'\n';
    //solve();
    return 0;
}