题目描述

直接点开题目看描述吧。

正解

数位 dp。

表示考虑前 位,两串 的差值为 的方案数(后面两个 表示两个串分别是否顶上界)。

再设一个辅助数组 表示考虑前 位,两串 的差值为 ,的所有方案下,需要移动的步数(后面两个 与上面定义相同)。

发现如果当前两个串 的个数相差为 可能为负),到下一位就会要产生 的贡献(移动 步)。

转移好像挺复杂的,我是自己画了个转移的图然后才写出来的。

note

由于题目中 串字典序要严格小于 串字典序,答案算出来还要除个

代码

这代码真不是给人看的

#include <bits/stdc++.h>

using namespace std;

const int N = 1005, D = 1002;
const int mod = 998244353;

int n;
char s[N];
unsigned long long f[2][N + N][2][2], g[2][N + N][2][2];

int main() {
    scanf("%s", s + 1), n = strlen(s + 1);
    int u = 0, v = 1;
    f[u][D][1][1] = 1;
    for(int i = 0; i < n; ++i) {
        memset(f[v], 0, sizeof f[v]);
        memset(g[v], 0, sizeof g[v]);
        for(int j = D - i; j <= D + i; ++j) {
            int c = abs(j - D);
            if(s[i + 1] == '0') {
                f[v][j][0][0] += 2 * f[u][j][0][0];
                g[v][j][0][0] += 2 * (g[u][j][0][0] + f[u][j][0][0] * c);
                f[v][j - 1][0][0] += f[u][j][0][0];
                g[v][j - 1][0][0] += g[u][j][0][0] + f[u][j][0][0] * c;
                f[v][j + 1][0][0] += f[u][j][0][0];
                g[v][j + 1][0][0] += g[u][j][0][0] + f[u][j][0][0] * c;
                f[v][j + 1][0][1] += f[u][j][0][1];
                g[v][j + 1][0][1] += g[u][j][0][1] + f[u][j][0][1] * c;
                f[v][j][0][1] += f[u][j][0][1];
                g[v][j][0][1] += g[u][j][0][1] + f[u][j][0][1] * c;
                f[v][j - 1][1][0] += f[u][j][1][0];
                g[v][j - 1][1][0] += g[u][j][1][0] + f[u][j][1][0] * c;
                f[v][j][1][0] += f[u][j][1][0];
                g[v][j][1][0] += g[u][j][1][0] + f[u][j][1][0] * c;
                f[v][j][1][1] += f[u][j][1][1];
                g[v][j][1][1] += g[u][j][1][1] + f[u][j][1][1] * c;
            } else {
                f[v][j][0][0] += 2 * f[u][j][0][0];
                g[v][j][0][0] += 2 * (g[u][j][0][0] + f[u][j][0][0] * c);
                f[v][j - 1][0][0] += f[u][j][0][0];
                g[v][j - 1][0][0] += g[u][j][0][0] + f[u][j][0][0] * c;
                f[v][j + 1][0][0] += f[u][j][0][0];
                g[v][j + 1][0][0] += g[u][j][0][0] + f[u][j][0][0] * c;
                f[v][j][0][0] += f[u][j][0][1];
                g[v][j][0][0] += g[u][j][0][1] + f[u][j][0][1] * c;
                f[v][j - 1][0][1] += f[u][j][0][1];
                g[v][j - 1][0][1] += g[u][j][0][1] + f[u][j][0][1] * c;
                f[v][j + 1][0][0] += f[u][j][0][1];
                g[v][j + 1][0][0] += g[u][j][0][1] + f[u][j][0][1] * c;
                f[v][j][0][1] += f[u][j][0][1];
                g[v][j][0][1] += g[u][j][0][1] + f[u][j][0][1] * c;
                f[v][j][0][0] += f[u][j][1][0];
                g[v][j][0][0] += g[u][j][1][0] + f[u][j][1][0] * c;
                f[v][j - 1][0][0] += f[u][j][1][0];
                g[v][j - 1][0][0] += g[u][j][1][0] + f[u][j][1][0] * c;
                f[v][j + 1][1][0] += f[u][j][1][0];
                g[v][j + 1][1][0] += g[u][j][1][0] + f[u][j][1][0] * c;
                f[v][j][1][0] += f[u][j][1][0];
                g[v][j][1][0] += g[u][j][1][0] + f[u][j][1][0] * c;
                f[v][j][0][0] += f[u][j][1][1];
                g[v][j][0][0] += g[u][j][1][1] + f[u][j][1][1] * c;
                f[v][j - 1][0][1] += f[u][j][1][1];
                g[v][j - 1][0][1] += g[u][j][1][1] + f[u][j][1][1] * c;
                f[v][j + 1][1][0] += f[u][j][1][1];
                g[v][j + 1][1][0] += g[u][j][1][1] + f[u][j][1][1] * c;
                f[v][j][1][1] += f[u][j][1][1];
                g[v][j][1][1] += g[u][j][1][1] + f[u][j][1][1] * c;
            }
        }
        for(int j = D - (i + 1); j <= D + i + 1; ++j) {
            f[v][j][0][0] %= mod;
            f[v][j][0][1] %= mod;
            f[v][j][1][0] %= mod;
            f[v][j][1][1] %= mod;
            g[v][j][0][0] %= mod;
            g[v][j][0][1] %= mod;
            g[v][j][1][0] %= mod;
            g[v][j][1][1] %= mod;
        }
        swap(u, v);
    }

    unsigned long long ans = 0;
    ans = (g[u][D][0][0] + g[u][D][0][1] + g[u][D][1][0]) % mod;
    ans = ans * (mod + 1) / 2 % mod;
    printf("%llu\n", ans);
    return 0;
}