题解 | CSP-S 2021 括号序列

目录

状态定义

定义 dp[l][r][flag]dp[l][r][flag] 为区间 [l,r][l, r] 中符合以下条件的方法数:

  1. flag = 0 表示左右两端为匹配括号的形式:(any)(any)

  2. flag = 1 表示左端为1-k个星号的形式: ..(any)*..*(any)

  3. flag = 2 表示右端为1-k个星号的形式: (any)..(any)*..*

则所求答案为dp[0][n1][0]dp[0][n-1][0]

状态转移方程

1. flag = 0时, 枚举左边第1个合法的超级括号序列

dp[l][r][0]=count(l,r)+i=l+1r1count(l,i)(dp[i+1][r][0]+dp[i+1][r][1])dp[l][r][0] = count(l, r) + \sum_{i=l+1}^{r-1}count(l, i) * (dp[i+1][r][0] + dp[i+1][r][1])
count(l,r)=countcase1(l,r)+i=03dp[l+1][r1][i]count(l, r) = count_{case1}(l,r) + \sum_{i=0}^{3}dp[l+1][r-1][i]

其中要求

  • s[l] = '(' 或 '?'

  • s[i] = ')' 或 '?'

  • s[r] = ')' 或 '?'

  • s[l..i]为第1个有如下形式的合法括号序列(),(..),((any)),(..(any)),((any)..)(), (*..*), ((any)), (*..*(any)), ((any)*..*)

  • count(l,r)count(l, r) 统计如下形式(),(..),((any)),(..(any)),((any)..) (), (*..*), ((any)), (*..*(any)), ((any)*..*) 的合法串数量

  • countcase1(l,r)=1count_{case1}(l, r) = 1 当且仅当 s[l, r]形如()()()(***)

这一步时间复杂度为O(kN2)O(kN^2)


2. flag = 1时

dp[l][r][1]=i=1kdp[l+i][r][0]dp[l][r][1] = \sum_{i=1}^{k}dp[l+i][r][0]

其中要求

  • s[l+i] = '(' 或 '?'

  • s[r] = ')' 或 '?'

  • s[l, l+i-1] 中每个字符 = '*' 或 '?'

这一步时间复杂度为O(kN2)O(kN^2)


3. flag = 2时

dp[l][r][2]=i=1kdp[l][ri][0]dp[l][r][2] = \sum_{i=1}^{k}dp[l][r-i][0]

其中要求

  • s[l] = '(' 或 '?'

  • s[r-i] = ')' 或 '?'

  • s[r-i+1, r] 中每个字符 = '*' 或 '?'

这一步时间复杂度为O(kN2)O(kN^2)


总的时间复杂度为O(N3)O(N^3)

代码

#include <bits/stdc++.h>
#define CLEAR(a,val) memset(a, val, sizeof (a))
using ll = long long;
using namespace std;
const ll MOD = 1e9 + 7;
const int MAXN = 501;
int mem[MAXN][MAXN][3];
int mem_case1[MAXN][MAXN];

int main() {
    CLEAR(mem, -1);
    CLEAR(mem_case1, -1);
    int n, k; cin >> n >> k;
    string s; cin >> s;
    // can s[i] turn into c
    auto check = [&s](int i, char c)->bool {
        return s[i] == c || s[i] == '?';
    };

    // trivial case
    auto count_case1 = [&](int l, int r) -> int {
        if (!check(l, '(')) return 0;
        if (!check(r, ')')) return 0;
        if (r - l - 1 > k) return 0;
        if (mem_case1[l][r] != -1) { return mem_case1[l][r]; }
        for(int i = l + 1; i <= r - 1; ++i) {
            if (!check(i, '*')) return mem_case1[l][r] = 0;
        }
        return mem_case1[l][r] = 1;
    };

    function<ll(int, int, int)> dp = [&](int l, int r, int flag) -> ll {
        if (r - l <= 0) return 0;
        if (mem[l][r][flag] != -1) { return mem[l][r][flag]; }
        ll ans = 0;
        if (flag == 0) { // format (~)
            if (!check(l, '(') || !check(r, ')')) {
                return mem[l][r][flag] = 0;
            }
            // () = count(l, r)
            ans = (ans + count_case1(l, r)) % MOD;
            for(int i = 0; i < 3; ++i) {
                ans = (ans + dp(l + 1, r - 1, i)) % MOD;
            }
            // ()any = range dp
            for(int i = l + 1; i <= r - 1; ++i) {
                if (check(i, ')')) {
                    ll cnt_left = count_case1(l, i) + dp(l + 1, i - 1, 0) + dp(l + 1, i - 1, 1) + dp(l + 1, i - 1, 2);
                    ll cnt_right = dp(i + 1, r, 0) + dp(i + 1, r, 1);
                    cnt_left %= MOD;
                    cnt_right %= MOD;
                    ans = (ans + cnt_left * cnt_right % MOD) % MOD;
                }
            }
        }
        else if (flag == 1 && check(r, ')')) { // format ***()
            for(int i = 1; i <= k; ++i) {
                int at = l + i - 1;
                if (at + 1 >= r) { break; }
                if (!check(at, '*')) { break; }
                if (check(at + 1, '(')) {
                    ans = (ans + dp(at + 1, r, 0)) % MOD;
                }
            }
        }
        else if (flag == 2 && check(l, '(')) { // format ()***
            for(int i = 1; i <= k; ++i) {
                int at = r - i + 1;
                if (at - 1 <= l) { break; }
                if (!check(at, '*')) { break; }
                if (check(at - 1, ')')) {
                    ans = (ans + dp(l, at - 1, 0)) % MOD;
                }
            }
        }
        return mem[l][r][flag] = int(ans % MOD);
    };
    cout << dp(0, n - 1, 0) << endl;

    return 0;
}