更新ing

菜鸡 \(wljss\) 来讲组合数学啦。

组合数学博大精深,主要是爱数数的人上大学了,从模拟赛到NOI都出现过。

一些技巧可以看这里

其实看完上边那个也就没啥好说的了

上例题吧

P1595 信封问题

这里有讲解,直接放代码233

#include<iostream>
using namespace std;
long long ans[21];
int main() 
{
    int n;
    cin >> n;
    ans[2] = 1; ans[3] = 2;
    for (int i = 4; i <= 20; ++i)
        ans[i] = (i - 1) * (ans[i - 1] + ans[i - 2]);
    cout << ans[n];
    return 0;
}

P3197 [HNOI2008]越狱

首先我们可以发现总的方案很好求,就是 \(m^n\).

所以我们采用求 补集 的思想

有多少可能不会发生越狱呢?

第一个人有 \(m\) 种可能的信仰。

从第二个人开始,每个人的信仰都要和上一个人不同,每个人有 \(m-1\) 种。

所以答案就是 \(m^n - m \times (m-1)^{n-1}\)

#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
LL n, m;
const LL mod = 100003;
LL ksm(LL a, LL b, LL mod) 
{
    LL ans = 1;
    for (; b; b >>= 1, a = a * a % mod) {if (b & 1)ans = ans * a % mod;}
    return ans;
}
int main() 
{
    cin >> m >> n;
    cout << (ksm(m, n, mod) + mod - m * ksm(m - 1, n - 1, mod) % mod) % mod;
    return 0;
}

P1057 传球游戏

\(dp[i][j]\) 为第 \(i\) 个人在第 \(j\) 轮拿到球的方案数。

每次考虑一个人的求可能从哪个方向上传过来.

\(f[i][j]=f[i-1][j-1]+f[i+1][j-1]\)

当然 \(1\) 号和 \(n\) 号要特殊处理一下。

#include<iostream>
using namespace std;
int f[31][31];
int main() 
{
    int n, m;
    cin >> n >> m;
    f[1][0] = 1;
    for (int i = 1; i <= m; ++i) 
    {
        f[1][i] = f[2][i - 1] + f[n][i - 1];
        for (int j = 2; j <= n - 1; ++j) 
            f[j][i] = f[j - 1][i - 1] + f[j + 1][i - 1];
        f[n][i] = f[n - 1][i - 1] + f[1][i - 1];
    }
    cout << f[1][m];
    return 0;
}

P6057 [加油武汉]七步洗手法

总的三元环个数很好求,任意选出 \(3\) 个点就能组成,所以总的三元环个数是 \(C_n^3\)

我们采用求 补集 的思想,不同色的三元环有多少个?

考虑每个点,如果有 \(d\) 条白边,那么就有 \(n-d-1\) 条黑边,两两组合就会产生 \(d \times (n-d-1)\) 个三元环

考虑这样我们求出来的是啥?考虑对于每个不同色的三元环,会被每对不同色的边各枚举一次,一共会被枚举 \(2\) 次.

所以\(/2\)后才是真正的不同色三元环个数。

#include<iostream>
#include<cstdio>
using namespace std;
int n, m;
long long tmp;
const int N = 100010;
int du[N];
int main() 
{
    cin >> n >> m;
    for (int i = 1, x, y; i <= m; ++i)
        scanf("%d%d", &x, &y), ++du[x], ++du[y];
    for (int i = 1; i <= n; ++i)tmp += (long long)du[i] * (n - du[i] - 1);
    cout << (long long)n*(n - 1)*(n - 2) / 6 - tmp / 2;
    return 0;
}

P1535 [USACO08MAR]Cow Travelling S

简单 \(DP\),设 \(dp[i][j][k]\) 为走了 \(k\) 步,走到坐标 \((i,j)\) 的方案数。

转移的话直接从周围 \(4\) 个方向转移就行了。

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
int n, m, t, x1, y1, x2, y2;
char mp[105][105];
int l[105][105];
long long dp[105][105][20];
int fx[10], fy[10];
int main() 
{
    cin >> n >> m >> t;
    for (int i = 1; i <= n; ++i)scanf("%s", mp[i] + 1);
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j)
            if (mp[i][j] == '*')l[i][j] = 1;
    cin >> x1 >> y1 >> x2 >> y2;
    fx[1] = 0; fx[2] = 0; fx[3] = 1; fx[4] = -1;
    fy[1] = 1; fy[2] = -1; fy[3] = 0; fy[4] = 0;
    dp[x1][y1][0] = 1;
    for (int k = 1; k <= t; ++k)
        for (int i = 1; i <= n; ++i)
            for (int j = 1; j <= m; ++j)
                for (int d = 1; d <= 4; ++d)
                    if (!l[i + fx[d]][j + fy[d]])
                        dp[i][j][k] += dp[i + fx[d]][j + fy[d]][k - 1];
    printf("%lld", dp[x2][y2][t]);
    fclose(stdin); fclose(stdout);
    return 0;
}

P1144 最短路计数

普通求最短路的话一定要用DIJ,SPFA早就死了。

\(f[i]\)当前情况下 \(i\) 的最短路计数

首先初始化 \(f[1]=1\)

然后考虑求最短路的过程.

如果 \(dis[to[i]] > dis[x] + 1\),由于最短路更新了, 当前情况最短路只能由 \(x\) 走过来,\(f[to[i]] = f[x]\)

如果 \(dis[to[i]] = dis[x] + 1\),最短路没有更新, 当前情况最短路也能由 \(x\) 走过来,\(f[to[i]] += f[x]\)

跑最短路的时候更新就行了.

#include<iostream>
#include<cstring>
#include<cstdio>
#include<queue>
#define pr pair<int,int>
using namespace std;
int n, m, tot;
const int N = 1000010, M = 2000010, mod = 100003;
int head[N], to[M << 1], nt[M << 1], dis[N], f[N], vis[N];
priority_queue<pr>q;
void add(int f, int t) 
{
    to[++tot] = t; nt[tot] = head[f]; head[f] = tot;
}
int main() 
{
    cin >> n >> m;
    for (int i = 1, x, y; i <= m; ++i)
        scanf("%d%d", &x, &y), add(x, y), add(y, x);
    memset(dis, 0x3f, sizeof(dis));
    dis[1] = 0; f[1] = 1; q.push(pr(0, 1));
    while (!q.empty()) 
    {
        int x = q.top().second; q.pop();
        if (vis[x])continue; vis[x] = 1;
        for (int i = head[x]; i; i = nt[i])
            if (dis[to[i]] == dis[x] + 1) {(f[to[i]] += f[x]) %= mod;}
            else if (dis[to[i]] > dis[x] + 1) 
            {
                dis[to[i]] = dis[x] + 1; f[to[i]] = f[x];
                q.push(pr(-dis[to[i]], to[i]));
            }
    }
    for (int i = 1; i <= n; ++i)printf("%d\n", f[i]);
    return 0;
}

P1450 [HAOI2008]硬币购物

首先每次背包一下答案是对的,但是复杂度太高。

对于这种有选取数量限制的计数题,我们通常枚举有哪些突破了数量限制,然后容斥。

对于这道题来说,我们先用四种面值做一个完全背包。

然后枚举哪些面值肯定会突破限制,这个可以通过 \(tmp -= c * (b + 1)\) 来实现,也就是先选出来 \(b+1\) 个,再怎么选都会突破限制。

答案并不是 有0个强制突破限制的情况,因为虽然没有强制突破的情况,但因为随便选还是有突破的情况。

所以我们需要减去强制有1个突破的情况。

然后会发现减的有点多,要加上强制有2个突破的情况.以此类推。

#include<iostream>
#include<cstdio>
using namespace std;
int n, s, opt;
long long ans, tmp;
int b[5], c[5];
long long f[100010];
int main() 
{
    cin >> c[1] >> c[2] >> c[3] >> c[4] >> n;
    f[0] = 1;
    for (int i = 1; i <= 4; ++i)
        for (int j = c[i]; j <= 100000; ++j)f[j] += f[j - c[i]];
    while (n--) 
    {
        scanf("%d%d%d%d%d", &b[1], &b[2], &b[3], &b[4], &s);
        ans = 0;
        for (int i = 0; i < (1 << 4) - 1; ++i) 
        {
            tmp = s; opt = 1;
            for (int j = 1; j <= 4; ++j)
                if ((i >> (j - 1)) & 1)tmp -= 1ll * c[j] * (b[j] + 1), opt = -opt;
            if (tmp < 0)continue;
            ans += opt * f[tmp];
        }
        cout << ans << '\n';
    }
    return 0;
}

P4071 [SDOI2016]排列计数

首先枚举哪些位置满足 \(a_i=i\),那么剩下的 \(n-m\) 个数就需要错排。

所以答案就是 \(C_n^m \times f[n-m]\)

#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
int T, n, m;
const int N = 1000010, M = 1000000, mod = 1e9 + 7;
LL jc[N], inv[N], f[N];
LL ksm(LL a, LL b, LL mod) 
{
    LL res = 1;
    for (; b; b >>= 1, a = a * a % mod)
        if (b & 1)res = res * a % mod;
    return res;
}
void YYCH() 
{
    jc[0] = jc[1] = inv[0] = inv[1] = 1;
    for (int i = 2; i <= M; ++i)jc[i] = jc[i - 1] * i % mod;
    inv[M] = ksm(jc[M], mod - 2, mod);
    for (int i = M - 1; i >= 1; --i)inv[i] = inv[i + 1] * (i + 1) % mod;

    f[0] = 1; f[1] = 0; f[2] = 1; f[3] = 2;
    for (int i = 4; i <= M; ++i)f[i] = (i - 1) * (f[i - 1] + f[i - 2]) % mod;
}
LL C(int n, int m) {return jc[n] * inv[m] % mod * inv[n - m] % mod;}
int main() 
{
    YYCH();
    cin >> T;
    while (T--) 
    {
        scanf("%d%d", &n, &m);
        printf("%lld\n", C(n, m)*f[n - m] % mod);
    }
    return 0;
}

P2513 [HAOI2009]逆序对数列

\(f[i][j]\) 为用 \(1\) ~ \(i\) 组成的序列有j个逆序对的方案数。

我们将每个数一个一个插入原序列,考虑第 \(i\) 个数放在原来的序列的哪个位置,由于之前的数都比 \(i\) 小,所以如果插在 \(k\) 个数后面,就会增加 \(k\) 个逆序对。

所以 \(\displaystyle f[i][j]=\sum_{k=0}^{min(i-1,j)}f[i-1][j-k]\)

发现后面那一段是连续的一段,可以用前缀和优化一下。

#include<iostream>
#include<cstdio>
using namespace std;
int n, k, tot = 0, p = 10000;
int ans[1001][1001];
int main() 
{
    cin >> n >> k;
    ans[1][0] = 1;
    for (int i = 2; i <= n; ++i) 
    {
        tot = 0;
        for (int j = 0; j <= k; ++j) 
        {
            tot += ans[i - 1][j];
            ans[i][j] = tot % p;
            if (j >= i - 1) 
            {
                tot = tot - ans[i - 1][j - i + 1];
                tot = (tot + p) % p;
            }
        }
    }
    cout << ans[n][k];
    return 0;
}

P5664 Emiya 家今天的饭

考虑容斥:合法方案:总方案-不合法方案

总方案很好求,可以 \(DP\) ,也可以用 \(\displaystyle (\prod_{i=1}^{n} (1 + \sum_{j=1}^{m}a[i][j]))-1\) 来求。

如果一个方案不合法,一定是某个食材出现次数超过了 \(\lfloor \frac{k}{2} \rfloor\) 次,并且能造成这个的最多只有 \(1\) 个食材。

枚举哪个食材不合法,\(DP\)\(g[i][j][k]\) 为前 \(i\) 个里面,一共选了 \(j\) 个,其中 \(id\) 选了 \(k\) 个的方案数

时间复杂度 \(O(m n^3)\) ,只有84分

#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
int n, m, ans;
const int N = 105, M = 2005, mod = 998244353;
int a[N][M];
LL f[N][N], g[N][N][N];
void solve1() 
{
    int tmp;
    f[0][0] = 1;
    for (int i = 1; i <= n; ++i) 
    {
        f[i][0] = 1; tmp = 0;
        for (int j = 1; j <= m; ++j)(tmp += a[i][j]) %= mod;
        for (int j = 1; j <= i; ++j)f[i][j] = (f[i - 1][j] + f[i - 1][j - 1] * tmp) % mod;
    }
}
void solve2(int id) 
{
    int tmp;
    g[0][0][0] = 1;
    for (int i = 1; i <= n; ++i) 
    {
        g[i][0][0] = 1; tmp = 0;
        for (int j = 1; j <= m; ++j)
            if (j != id)(tmp += a[i][j]) %= mod;
        for (int j = 1; j <= i; ++j) 
        {
            for (int k = 0; k <= j; ++k)g[i][j][k] = g[i - 1][j][k];
            for (int k = 0; k <= j; ++k)(g[i][j][k] += g[i - 1][j - 1][k] * tmp) %= mod;
            for (int k = 1; k <= j; ++k)(g[i][j][k] += g[i - 1][j - 1][k - 1] * a[i][id]) %= mod;
        }
    }
}
int main() 
{
    cin >> n >> m;
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j)scanf("%d", &a[i][j]);
    solve1();
    for (int i = 1; i <= m; ++i) 
    {
        solve2(i);
        for (int j = 1; j <= n; ++j)
            for (int k = j / 2 + 1; k <= j; ++k)
                (f[n][j] -= g[n][j][k]) %= mod;
    }
    for (int i = 1; i <= n; ++i)(ans += f[n][i]) %= mod;
    cout << (ans % mod + mod) % mod;
    return 0;
}

发现同时记录 \(j\)\(k\) 就有点多余,重新设 \(g[i][j]\) 为前 \(i\) 个里,\(id\) 选的比其他的多 \(j\) 个时的方案数。

时间复杂度就成了 \(O(mn^2)\),可以过,但是因为可能 \(j\) 有可能是负数,所以需要整体下标加上 \(n\)

#include<iostream>
#include<cstring>
#include<cstdio>
#define LL long long
using namespace std;
int n, m, ans;
const int N = 105, M = 2005, mod = 998244353;
int a[N][M];
LL f[N][N], s[N], g[N][N << 1];
void solve1() 
{
    ans = 1;
    for (int i = 1; i <= n; ++i) 
    {
        for (int j = 1; j <= m; ++j)(s[i] += a[i][j]) %= mod;
        ans = (LL)ans * (s[i] + 1) % mod;
    }
    --ans;
}
void solve2(int id) 
{
    int tmp;
    g[0][n] = 1;
    for (int i = 1; i <= n; ++i) 
    {
        tmp = (s[i] - a[i][id]) % mod;
        for (int j = 0; j <= 2 * n; ++j)g[i][j] = g[i - 1][j];
        for (int j = 0; j <= 2 * n - 1; ++j)(g[i][j] += g[i - 1][j + 1] * tmp) %= mod;
        for (int j = 1; j <= 2 * n; ++j)(g[i][j] += g[i - 1][j - 1] * a[i][id]) %= mod;
    }
}
int main() 
{
    cin >> n >> m;
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j)scanf("%d", &a[i][j]);
    solve1();
    for (int i = 1; i <= m; ++i) 
    {
        solve2(i);
        for (int j = 1; j <= n; ++j)(ans -= g[n][n + j]) %= mod;
    }
    cout << (ans % mod + mod) % mod;
    return 0;
}