题解 | CSP-S 2021 括号序列
目录
状态定义
定义 为区间 中符合以下条件的方法数:
-
flag = 0 表示左右两端为匹配括号的形式:
-
flag = 1 表示左端为1-k个星号的形式:
-
flag = 2 表示右端为1-k个星号的形式:
则所求答案为
状态转移方程
1. flag = 0时, 枚举左边第1个合法的超级括号序列
其中要求
-
s[l] = '(' 或 '?'
-
s[i] = ')' 或 '?'
-
s[r] = ')' 或 '?'
-
s[l..i]为第1个有如下形式的合法括号序列
-
统计如下形式 的合法串数量
-
当且仅当 s[l, r]形如或
这一步时间复杂度为
2. flag = 1时
其中要求
-
s[l+i] = '(' 或 '?'
-
s[r] = ')' 或 '?'
-
s[l, l+i-1] 中每个字符 = '*' 或 '?'
这一步时间复杂度为
3. flag = 2时
其中要求
-
s[l] = '(' 或 '?'
-
s[r-i] = ')' 或 '?'
-
s[r-i+1, r] 中每个字符 = '*' 或 '?'
这一步时间复杂度为
总的时间复杂度为
代码
#include <bits/stdc++.h>
#define CLEAR(a,val) memset(a, val, sizeof (a))
using ll = long long;
using namespace std;
const ll MOD = 1e9 + 7;
const int MAXN = 501;
int mem[MAXN][MAXN][3];
int mem_case1[MAXN][MAXN];
int main() {
CLEAR(mem, -1);
CLEAR(mem_case1, -1);
int n, k; cin >> n >> k;
string s; cin >> s;
// can s[i] turn into c
auto check = [&s](int i, char c)->bool {
return s[i] == c || s[i] == '?';
};
// trivial case
auto count_case1 = [&](int l, int r) -> int {
if (!check(l, '(')) return 0;
if (!check(r, ')')) return 0;
if (r - l - 1 > k) return 0;
if (mem_case1[l][r] != -1) { return mem_case1[l][r]; }
for(int i = l + 1; i <= r - 1; ++i) {
if (!check(i, '*')) return mem_case1[l][r] = 0;
}
return mem_case1[l][r] = 1;
};
function<ll(int, int, int)> dp = [&](int l, int r, int flag) -> ll {
if (r - l <= 0) return 0;
if (mem[l][r][flag] != -1) { return mem[l][r][flag]; }
ll ans = 0;
if (flag == 0) { // format (~)
if (!check(l, '(') || !check(r, ')')) {
return mem[l][r][flag] = 0;
}
// () = count(l, r)
ans = (ans + count_case1(l, r)) % MOD;
for(int i = 0; i < 3; ++i) {
ans = (ans + dp(l + 1, r - 1, i)) % MOD;
}
// ()any = range dp
for(int i = l + 1; i <= r - 1; ++i) {
if (check(i, ')')) {
ll cnt_left = count_case1(l, i) + dp(l + 1, i - 1, 0) + dp(l + 1, i - 1, 1) + dp(l + 1, i - 1, 2);
ll cnt_right = dp(i + 1, r, 0) + dp(i + 1, r, 1);
cnt_left %= MOD;
cnt_right %= MOD;
ans = (ans + cnt_left * cnt_right % MOD) % MOD;
}
}
}
else if (flag == 1 && check(r, ')')) { // format ***()
for(int i = 1; i <= k; ++i) {
int at = l + i - 1;
if (at + 1 >= r) { break; }
if (!check(at, '*')) { break; }
if (check(at + 1, '(')) {
ans = (ans + dp(at + 1, r, 0)) % MOD;
}
}
}
else if (flag == 2 && check(l, '(')) { // format ()***
for(int i = 1; i <= k; ++i) {
int at = r - i + 1;
if (at - 1 <= l) { break; }
if (!check(at, '*')) { break; }
if (check(at - 1, ')')) {
ans = (ans + dp(l, at - 1, 0)) % MOD;
}
}
}
return mem[l][r][flag] = int(ans % MOD);
};
cout << dp(0, n - 1, 0) << endl;
return 0;
}