G-subsequence 1

题意

给你两个字符串\(s、t\),问\(s\)中有多少个子序列能大于\(t\)

思路

\(len1\)\(s\)的子序列的长度,\(lent\)\(t\)的长度

  1. \(len1 > lent\):枚举每一位,如果当前为不为\(0\)那么它后面的位置可以随意取,\(num = num + \binom{k}{len-1}、k\)是当前位到\(s\)的末尾剩下的位数
  2. \(len1 = lent\):暴力\(n^3\)肯定超时,所以要用\(dp\)优化
    \(dp[i][j][1]\)\(s[j]\)作为第\(i\)个数大于\(t[1\)~\(i]\)前缀的个数
    \(dp[i][j][2]\)\(s[j]\)作为第\(i\)个数等于\(t[1\)~\(i]\)前缀的个数
    • \(s[j] > t[i]\)\(dp[i][j][1] = dp[i-1][1\) ~ \(j-1][1]+dp[i-1][1\) ~ \(j-1][2]\)\(dp[i][j][2] = 0\)

    • \(s[j] = t[i]\)\(dp[i][j][1] = dp[i-1][1\) ~ \(j-1]\)\(dp[i][j][2] = dp[i-1][1\) ~ \(j-1][2]\)

    • \(s[j] < t[i]\)\(dp[i][j][1] = dp[i-1][1\) ~ \(j-1]\)\(dp[i][j][2] = 0\)
  3. 用一个前缀和维护一下\(dp[i-1]\)的前缀,就可以把\(dp\)优化到\(n^2\)

AC 代码

#include<bits/stdc++.h>
#define mes(a, b) memset(a, b, sizeof a)
using namespace std;
typedef long long ll;
const int maxn = 3e3+10;
const ll mod = 998244353;
struct A{
    int num[3][maxn];
    void init(){
        mes(num, 0);
    }
}a, b;
char s[maxn], t[maxn];
ll dp[maxn][maxn][3];
ll C[maxn][maxn];
void init(){    //组合数打表
    C[0][0] = C[1][0] = C[1][1] = 1;
    for(int i = 2; i < maxn;i++){
        for(int j = 0; j <= i; j++){
            C[i][j] = j==0?1:C[i-1][j-1]+C[i-1][j];
            C[i][j] %= mod;
        }
    }
}
 
int main(){
    int T, n, m;
    scanf("%d", &T);
    init();
    while(T--){
        scanf("%d%d", &n, &m);
        scanf("%s%s", s+1, t+1);
        ll ans = 0;
        a.init();   //表示dp[i-1]的前缀和
        b.init();   //表示dp[i]的前缀和
        for(int i = 1; i <= n-m; i++){
            if(s[i] != '0')
                for(int j = m; j <= n-i; j++){
                    ans = (ans + C[n-i][j])%mod;
                }
        }
        for(int i = 0; i <= n; i++){    //初始化
            a.num[2][i] = 1;        
        }
        for(int i = 1; i <= m; i++){
            for(int j = 1; j <= n; j++){
                if(s[j] > t[i]){
                    dp[i][j][1] = (a.num[1][j-1]+a.num[2][j-1])%mod;
                    dp[i][j][2] = 0;
                }
                else if(s[j] == t[i]){
                    dp[i][j][1] = a.num[1][j-1];
                    dp[i][j][2] = a.num[2][j-1];
                }
                else{
                    dp[i][j][1] = a.num[1][j-1];
                    dp[i][j][2] = 0;
                }
                b.num[1][j] = (b.num[1][j-1] + dp[i][j][1])%mod;
                b.num[2][j] = (b.num[2][j-1] + dp[i][j][2])%mod;
            }
            swap(a, b);
            b.num[1][0] = b.num[2][0] = 0;
        }
        ans = (ans + a.num[1][n])%mod;
        printf("%lld\n", ans);
    }
    return 0;
}