题目-Devu和鲜花

在这里插入图片描述

问题分析

首先假设简单情况, 如果每个盒子的花有无限个, 那么假设在盒子 1 1 1中取 x 1 x_1 x1, 在盒子 2 2 2中取 x 2 x_2 x2, 那么有
x 1 + x 2 + x 3 + . . . + x n = M x_1 + x_2 + x_3 + ... + x_n = M x1+x2+x3+...+xn=M
因为 x ≥ 0 x \ge 0 x0, 上述方程可以转化为

x 1 + x 2 + x 3 + . . . + x n = M + N , ( x ≥ 1 ) x_1 + x_2 + x_3 + ... + x_n = M + N, (x \ge 1) x1+x2+x3+...+xn=M+N,(x1)

根据隔板法, 方案数量就是 C N + M − 1 N − 1 C_{N + M - 1} ^ {N - 1} CN+M1N1

问题就变成了如何根据简单情况, 计算题目要求的 x 1 ≤ A 1 x_1 \le A_1 x1A1, x 2 ≤ A 2 x_2 \le A_2 x2A2, …, x n ≤ A n x_n \le A_n xnAn复杂情况

因为计算每种花取的上界不好计算, 可以采用补集的思想, 用总的方案数减去不合法的方案数, 就是最终答案

不合法的方案数就是至少一个箱子的花朵数量超过了 A i A_i Ai的方案数量

我们设 x 1 > A i x_1 > A_i x1>Ai的情况是 S 1 S_1 S1, x 2 > A 2 x_2 > A_2 x2>A2的情况是 S 2 S_2 S2, …, x n > A n x_n > A_n xn>An的情况是 S n S_n Sn, 那么不合法的方案数量就是

∣ S 1 ∪ S 2 . . . S n ∣ |S_1 \cup S_2 ... S_n| S1S2...Sn

根据容斥原理展开得到

∣ S 1 ∣ + ∣ S 2 ∣ + . . . + ∣ S n ∣ − ∣ S 1 ∩ S 2 ∣ − ∣ S 1 ∩ S 3 ∣ − . . . . + ∣ S 1 ∩ S 2 ∩ S 3 ∣ + . . . |S_1| + |S_2| + ... + |S_n| - |S_1 \cap S_2| - |S_1 \cap S_3| - .... + |S_1 \cap S_2 \cap S_3| + ... S1+S2+...+SnS1S2S1S3....+S1S2S3+...

现在问题就变成了, 如何求出 S i S_i Si S i ∩ S j S_i \cap S_j SiSj的情况

  • 假设情况是 x i > A i x_i > A_i xi>Ai, 那么可以先在 A i A_i Ai中取出 x i + 1 x_i + 1 xi+1个花朵, 然后剩下的情况依旧是隔板法, 方案数等于 C N + M − 1 − ( A i + 1 ) N − 1 C_{N + M - 1- (A_i + 1)} ^ {N - 1} CN+M1(Ai+1)N1
  • 假设情况是 x i > A i x_i > A_i xi>Ai x j > A j x_j > A_j xj>Aj的情况, 同理得到方案数 C N + M − 1 − ( A i + 1 ) − ( A j + 1 ) N − 1 C_{N + M - 1 - (A_i+ 1) - (A_j + 1)} ^ {N - 1} CN+M1(Ai+1)(Aj+1)N1

现在的问题就是如何计算组合数, 观察数据范围, M M M 1 0 14 10 ^ {14} 1014量级, N ≤ 20 N \le 20 N20, A i ≤ 1 0 12 A_i \le 10 ^ {12} Ai1012, 因为 N N N很小, 直接可以从定义计算, 并且因为 P P P是质数, 可以快速幂计算乘法逆元

算法步骤

组合数公式推导
C n m = n ! ( n − m ) ! m ! = n × ( n − 1 ) × ( n − 2 ) × . . . × ( n − m + 1 ) m ! C_n ^ m = \frac{n!}{(n - m)!m!} = \frac{n \times (n - 1) \times (n - 2) \times ... \times (n - m + 1)}{m!} Cnm=(nm)!m!n!=m!n×(n1)×(n2)×...×(nm+1)
算法时间复杂度 O ( N ⋅ 2 N ) O(N \cdot 2 ^ N) O(N2N)

代码实现

数据范围爆炸导致一个点无法通过

#include <bits/stdc++.h>

using namespace std;

typedef long long LL;
const int N = 25, MOD = 1e9 + 7;

int n;
LL A[N], m;

LL q_pow(LL a, LL b, int mod) {
   
    LL ans = 1;
    a %= mod;
    while (b) {
   
        if (b & 1) ans = ans % mod * a % mod;
        a = a % mod * a % mod;
        b >>= 1;
    }
    return ans;
}

LL C(int a, int b) {
   
    if (a < b || b < 0) return 0;
    LL up = 1, down = 1;
    for (LL i = a; i >= a - b + 1; --i) up = up * i % MOD;
    for (LL i = 1; i <= b; ++i) down = down * i % MOD;
    LL ans = up * q_pow(down, MOD - 2, MOD);
    return ans;
}


int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n >> m;
    for (int i = 0; i < n; ++i) cin >> A[i];
    LL ans = C(n + m - 1, n - 1) % MOD;

    for (int i = 1; i < 1 << n; ++i) {
   
        int cnt = 0;
        LL a = n + m - 1, b = n - 1;
        for (int j = 0; j < n; ++j) {
   
            if (i >> j & 1) {
   
                a -= (A[j] + 1);
                cnt++;
            }
        }
        (cnt & 1) ? ans = (ans - C(a, b)) % MOD : ans = (ans + C(a, b)) % MOD;
    }

    ans = (ans % MOD + MOD) % MOD;
    cout << ans << '\n';

    return 0;
}

a c ac ac代码

#include <bits/stdc++.h>

using namespace std;

typedef long long LL;
const int N = 25, MOD = 1e9 + 7;

int n;
LL A[N], m;
LL infact;

LL q_pow(LL a, LL b, int mod) {
   
    LL ans = 1;
    a %= mod;
    while (b) {
   
        if (b & 1) ans = ans % mod * a % mod;
        a = a % mod * a % mod;
        b >>= 1;
    }
    return ans;
}

LL C(LL a, LL b) {
   
    if (a < b || b < 0) return 0;
    LL up = 1;
    for (LL i = a; i >= a - b + 1; --i) {
   
        LL val = (i % MOD + MOD) % MOD;
        up = up % MOD * val % MOD;
    }
    LL ans = up * infact % MOD;
    return ans;
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n >> m;
    for (int i = 0; i < n; ++i) cin >> A[i];

    infact = 1;
    for (LL i = 1; i <= n - 1; ++i) infact = infact * i % MOD;
    infact = q_pow(infact, MOD - 2, MOD) % MOD;

    LL ans = C(n + m - 1, n - 1) % MOD;

    for (int i = 1; i < 1 << n; ++i) {
   
        int cnt = 0;
        LL a = n + m - 1, b = n - 1;
        for (int j = 0; j < n; ++j) {
   
            if (i >> j & 1) {
   
                a -= (A[j] + 1);
                cnt++;
            }
        }
        (cnt & 1) ? ans = (ans - C(a, b) + MOD) % MOD : ans = (ans + C(a, b)) % MOD;
    }

    ans = (ans % MOD + MOD) % MOD;
    cout << ans << '\n';

    return 0;
}