https://blog.csdn.net/qq_43804974/article/details/103905708
上面是csdn链接给点访问量吧

题目大意就是给一个长的字符串a和一个短的字符串b,询问a里面有多少个序列是大于b的。

这里我们分两步来处理,就是我们先算出a中的序列长度和b一样的。因为显然答案是由序列长度一样的数量+序列长度>b的数量。为什么不能直接一起算,因为dp的限制(下面讲)?

(1)计算序列长度一样的数量。我们这么考虑由于序列长度是一样的,我们就可以枚举位数的时候有一个比较值。我们类比一下数位dp,数位dp是只要我的高位比限制的范围小,那么地位就可以0-9随便取。转到到我们这道题就是,
我们要最终序列比b大,那么就是我在高位只要有某一位比b的对应位大,那么我下面是随便取都可以满足要求。
我们就这么去设计f[i][j][0/1],i表示现在枚举到b串的第几位,j表示现在枚举到a串的第几位,如果新序列第i位为a[j]那么合法序列是多少个。0表示这个序列正好和b序列一样,1表示高位的某一位已经大于b了所以后面瞎取都可以
这里对于方程的转移就要分三种情况

  1. a[j] < b[i]; 如果a中第j的数字是小于b中的i的数字,那么很明显,我们如果要出现一个大于b的有效序列那么只能是之前的高位已经有大于b所对应数字。即: f[i][j][1] = sum(f[i-1][k][1],(0=<k<= j-1)).,这里无法转移出f[i][j][0]因为你在这一位已经比b小了,不可能出现一样的序列了。
  2. a[j] == b[i] 那么我们要出现合法序列就是前方的序列不管是大于b还是正好等于b都一起加上,对于f[i][j][0] = sum(f[i-1][k][0],(0=<k<= j-1)).,这个式子不能加上f[i-1][k][1]因为你前面大于标准序列了,那么你后面不可能出现与标准序列一样的序列。但是,这个情况我们还可以推出f[i][j][1],因为只要选取前面的序列大于标准序列的就好了,方程和1的方程同理。
  3. a[j] > b[i] 好这一位的值比标准序列的那一位大,那么说明什么?我们之前出现的刚好等于的,和以及超过的都可以算加上去,也就是f[i][j][1] = sum(f[i-1][k][0]+f[i-1][k][1])(0=<k<= j-1),因为我们这一位是比标准的序列大的,所以可以吧前面不管是正好等于的,还是超过的都加上。
    这就是答案的一部分,很明显上面的sum是可以利用前缀和数组去O(1)算出来的,看到这里你知道为什么这种方程不能转移出序列长度大于标准序列的情况了吗,因为比标准序列高位的状态我们无法表示,存标准序列是[1,m]但是我们要表示高位向地位也是[1,k],如果用1表示最高位,就于存入数据的b[1]矛盾。因为我们本来是原本没有存在的最高位只能是0.

我们来考虑另一部分,我们怎么算长度大于标准序列的数目?这里由于我菜 没去想到组合数的方法。我用了第二个dp。
用dp[i][j]表示i个数选了j个的方案数。但是当j等于1的时候如果i所代表的数字是0那么则不能+1,反之就可以+1,这个dp就蛮简单的了。
dp[i][j] += dp[i-1][j];//这个数字我选了
dp[i][j] += dp[i-1][j-1];//这个数字不选
稍稍特判j等于1的情况就好了
dp[i][1] = dp[i - 1][1];如果a[i] != '0'则dp[i][1]++;

最后!两个答案一加就是了!

#include<iostream>
#include<cstdio>
#include<stack>
#include<algorithm>
#include<cstring>
#include<queue>
#include<vector>
#include<time.h>
#include<string>
#include<cmath>
#include <ctime>
#include<bitset>
#include <cctype>
#define debug cout<<"*********degug**********"<<endl;
#define ll long long
#define yn yn_
#define RE register
using namespace std;
const long long max_ = 3000 + 10;
const int mod = 998244353;
const int inf = 1e9;
const long long INF = 1e18;
int read() {
    int s = 0, f = 1;
    char ch = getchar();
    while (ch<'0' || ch>'9') {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0'&&ch <= '9') {
        s = s * 10 + ch - '0';
        ch = getchar();
    }
    return s * f;
}
int max(int a, int b) {
    return a > b ? a : b;
}
int min(int a, int b) {
    return a < b ? a : b;
}
int T, n, m, f[max_][max_][3], sum0[max_][max_], sum1[max_][max_], g[max_][max_];
char nan_[max_], bei[max_];
signed main() {
    cin >> T;
    while (T--) {
        cin >> n >> m;
        scanf("%s", nan_ + 1);
        scanf("%s", bei + 1);
        //cin >> (nan_ + 1);
        //cin >> (bei + 1);
        //找nan_的子序列比bei大的数量
        f[0][0][0] = 1; sum0[0][0] = 1;
        for (int j = 1; j <= n; j++) {
            sum0[0][j] += sum0[0][j - 1];
        }

        for (int i = 1; i <= m; i++) {
            //现在是bei的第i位
            for (int j = 1; j <= n; j++) {
                if (nan_[j] - '0' < bei[i] - '0') {
                    f[i][j][1] += (sum1[i - 1][j - 1] % mod); f[i][j][1] %= mod;
                }
                else
                    if (nan_[j] - '0' == bei[i] - '0') {
                        f[i][j][0] += sum0[i - 1][j - 1]; f[i][j][0] %= mod;
                        f[i][j][1] += sum1[i - 1][j - 1]; f[i][j][1] %= mod;
                    }
                    else {
                        f[i][j][1] += ((sum0[i - 1][j - 1] % mod) + (sum1[i - 1][j - 1] % mod)) % mod; f[i][j][1] %= mod;
                    }
                sum0[i][j] += sum0[i][j - 1] + f[i][j][0]; sum0[i][j] %= mod;
                sum1[i][j] += sum1[i][j - 1] + f[i][j][1]; sum1[i][j] %= mod;
            }
        }
        int ans = 0;
        ans += sum1[m][n];
        ans %= mod;
        for (int i = 1; i <= n; i++) {
            g[i][1] = g[i - 1][1]; g[i][1] %= mod;
            if (nan_[i] - '0' != 0)
                g[i][1]++;
            for (int j = 2; j <= i; j++) {
                g[i][j] = g[i - 1][j];
                g[i][j] %= mod;
                g[i][j] += g[i - 1][j - 1];
                g[i][j] %= mod;
            }
        }
        for (int j = m + 1; j <= n; j++) {
            ans += g[n][j];
            ans %= mod;

        }
        cout << ans << "\n";
        int tt = max(n, m) + 1;
        for (int i = 0; i <= tt; i++) {
            for (int j = 0; j <= tt; j++) {
                f[i][j][0] = f[i][j][1] = sum0[i][j] = sum1[i][j] = g[i][j] = 0;
            }
        }
    }
    return 0;
}