看了看dalao们的解释和代码,于是有了这篇题解

题目大意

给定两个字符串A,B,求出满足以下条件的子序列a,b(可以不连续)的数量 ,并对取模:

  • a,b分别来自A,B,且长度相同
  • ,使得
  • ​,满足​​。
  • 对于​没有任何限制。

思路

  • 假如我们已经找到了,和相同,那么我们只需要在数组后面中找到一个字母 ,在数组后面中找到一个字母,满足​即可,那么剩下的字母就可以随便选了​。
  • 假设数组还剩个数可选,数组还剩个可选,那么利用组合数学,就可知道共有种选法。而此式子等于。(具体证明可以参照百度)。(蒟蒻不太会)
  • 那么剩下的就是统计​数组有多少个字串是相同的了。我们假设​表示在​数组前​位,​数组的前​位***有几个相同的字串,递推时(有点类似与最长公共子序列):
    • 如果​,那么它显然是由​数组前​个字母和​数组前​个组成,或者是​数组前​个字母和​数组前​个组成的。但是如果我们直接加上​的话,会出现多加的情况,那就是​数组前​个字母和​数组前​​个组成的,所以我们要再减去,综上:​​。​
    • 如果,那么它显然也可以由得到一部分结果,其他的部分,就是由数组前个字母和数组前个,再加上两个字母,即,其次两个字母也可以单独作为结果计算,综上:
  • 因为结果是要求对结果进行取模的,而且我们的计算中是存在组合数的,所以不能每次都用费马小定理,所以我们需要预处理出阶乘数组和其对应的乘法逆元数组。

代码

#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;

typedef  long long ll;
const int maxN = 2000005, mod = 1000000007;

ll inv[maxN + 1], f[maxN + 1];

string s1, s2;
long long dp[5005][5005];
int len1, len2;

void exgcd(int a, int b, ll &x, ll &y) //拓展欧几里得
{
    if(b == 0) {
        x = 1; y = 0;
        return ;
    }
    exgcd(b, a % b, y, x);
    y -= a / b * x;
}

void init()
{
    f[0] = 1;
    for(int i = 1; i <= maxN; ++i)
        f[i] = f[i - 1] * i % mod; //阶乘数组
    ll x, y;
    exgcd(f[maxN], mod, x, y);
    inv[maxN] = (x % mod + mod) % mod;
    for(int i = maxN - 1; i; --i) {
        inv[i] = inv[i + 1] * (i + 1) % mod; //逆元数组
    }
}

ll C(ll n, ll m)
{
    if(n == m || m == 0)
        return 1;
    if(m > n)
        return 0;
    return (f[n] * inv[m] % mod * inv[n - m] % mod) % mod;
}

int main()
{
    cin >> s1; cin >> s2;
    len1 = s1.length(); len2 = s2.length();
    s1 = " " + s1; s2 = " " + s2;
    init();
    for(int i = 1; i <= len1; ++i) {
        for(int j = 1; j <= len2; ++j) {
            if(s1[i] == s2[j])
                dp[i][j] = (dp[i - 1][j] + dp[i][j - 1] + 1) % mod;
            else
                dp[i][j] = ((dp[i - 1][j] + dp[i][j - 1]) % mod + mod - dp[i - 1][j - 1]) % mod;
        }
    }
    long long ans = 0;
    for(int i = 1; i <= len1; ++i) {
        for(int j = 1; j <= len2; ++j) {
            if(s1[i] < s2[j]) {
                long long n = len1 - i, m = len2 - j;
                ans = (ans + (dp[i - 1][j - 1] + 1ll) * C(1LL * n + m, 1LL * min(n, m)) % mod) % mod;
            }
        }
    }
    cout << ans << endl;
    return 0;
}