状态定义
定义 d p [ l ] [ r ] [ f l a g ] dp[l][r][flag] dp[l][r][flag] 为区间 [ l , r ] [l,r] [l,r] 中符合以下条件的方法数:
- flag = 0 表示左右两端为匹配括号的形式: ( a n y ) (any) (any)
- flag = 1 表示左端为1-k个星号的形式: ∗ . . ∗ ( a n y ) ∗..∗(any) ∗..∗(any)
- flag = 2 表示右端为1-k个星号的形式: ( a n y ) ∗ . . ∗ (any)∗..∗ (any)∗..∗
则所求答案为 d p [ 0 ] [ n − 1 ] [ 0 ] dp[0][n−1][0] dp[0][n−1][0]
状态转移方程
1. flag = 0时, 枚举左边第1个合法的超级括号序列
d p [ l ] [ r ] [ 0 ] = c o u n t ( l , r ) + ∑ i = l + 1 r − 1 c o u n t ( l , i ) ∗ ( d p [ i + 1 ] [ r ] [ 0 ] + d p [ i + 1 ] [ r ] [ 1 ] ) dp[l][r][0]=count(l,r)+\sum_{i=l+1}^{r-1}count(l,i)*(dp[i+1][r][0]+dp[i+1][r][1]) dp[l][r][0]=count(l,r)+i=l+1∑r−1count(l,i)∗(dp[i+1][r][0]+dp[i+1][r][1])
c o u n t ( l , r ) = c o u n t c a s e 1 ( l , r ) + ∑ i = 0 3 d p [ l + 1 ] [ r − 1 ] [ i ] count(l,r)=count_{case1}(l,r)+\sum_{i=0}^{3}dp[l+1][r-1][i] count(l,r)=countcase1(l,r)+i=0∑3dp[l+1][r−1][i]
其中要求
-
s [ l ] s[l] s[l] = ‘(’ 或 ‘?’
-
s [ i ] s[i] s[i] = ‘)’ 或 ‘?’
-
s [ r ] s[r] s[r] = ‘)’ 或 ‘?’
-
s [ l . . i ] s[l..i] s[l..i]为第1个有如下形式的合法括号序列 ( ) , ( ∗ . . ∗ ) , ( ( a n y ) ) , ( ∗ . . ∗ ( a n y ) ) , ( ( a n y ) ∗ . . ∗ ) (), (∗..∗), ((any)), (∗..∗(any)), ((any)∗..∗) (),(∗..∗),((any)),(∗..∗(any)),((any)∗..∗)
-
c o u n t ( l , r ) count(l,r) count(l,r) 统计如下形式 ( ) , ( ∗ . . ∗ ) , ( ( a n y ) ) , ( ∗ . . ∗ ( a n y ) ) , ( ( a n y ) ∗ . . ∗ ) (), (∗..∗), ((any)), (∗..∗(any)), ((any)∗..∗) (),(∗..∗),((any)),(∗..∗(any)),((any)∗..∗) 的合法串数量
-
c o u n t c a s e 1 ( l , r ) = 1 count_{case1}(l, r) = 1 countcase1(l,r)=1 当且仅当 s [ l , r ] s[l, r] s[l,r] 形如 ( ) () () 或 ( ∗ ∗ ∗ ) (∗∗∗) (∗∗∗)
这一步时间复杂度为 O ( k N 2 ) O(kN^2) O(kN2)
2. flag = 1时
d p [ l ] [ r ] [ 1 ] = ∑ i = 1 k d p [ l + i ] [ r ] [ 0 ] dp[l][r][1]=\sum_{i=1}^{k}dp[l+i][r][0] dp[l][r][1]=i=1∑kdp[l+i][r][0]
其中要求
-
s [ l + i ] = s[l+i]= s[l+i]= ‘(’ 或 ‘?’
-
s [ r ] = s[r]= s[r]= ‘)’ 或 ‘?’
-
s [ l , l + i − 1 ] s[l, l+i-1] s[l,l+i−1] 中每个字符 = = = ‘*’ 或 ‘?’
这一步时间复杂度为 O ( k N 2 ) O(kN^2) O(kN2)
3. flag = 2时
d p [ l ] [ r ] [ 2 ] = ∑ i = 1 k d p [ l ] [ r − i ] [ 0 ] dp[l][r][2]=\sum_{i=1}^{k}dp[l][r-i][0] dp[l][r][2]=i=1∑kdp[l][r−i][0]
其中要求
-
s [ l ] = s[l] = s[l]= ‘(’ 或 ‘?’
-
s [ r − i ] = s[r-i] = s[r−i]= ‘)’ 或 ‘?’
-
s [ r − i + 1 , r ] s[r-i+1, r] s[r−i+1,r] 中每个字符 = = = ‘*’ 或 ‘?’
这一步时间复杂度为 O ( k N 2 ) O(kN^2) O(kN2)
总的时间复杂度为 O ( k N 2 ) O(kN^2) O(kN2)
AC code:
#include <bits/stdc++.h>
#define CLEAR(a,val) memset(a, val, sizeof (a))
using ll = long long;
using namespace std;
const ll MOD = 1e9 + 7;
const int MAXN = 501;
int mem[MAXN][MAXN][3];
int mem_case1[MAXN][MAXN];
int main() {
CLEAR(mem, -1);
CLEAR(mem_case1, -1);
int n, k; cin >> n >> k;
string s; cin >> s;
// can s[i] turn into c
auto check = [&s](int i, char c)->bool {
return s[i] == c || s[i] == '?';
};
// trivial case
auto count_case1 = [&](int l, int r) -> int {
if (!check(l, '(')) return 0;
if (!check(r, ')')) return 0;
if (r - l - 1 > k) return 0;
if (mem_case1[l][r] != -1) { return mem_case1[l][r]; }
for(int i = l + 1; i <= r - 1; ++i) {
if (!check(i, '*')) return mem_case1[l][r] = 0;
}
return mem_case1[l][r] = 1;
};
function<ll(int, int, int)> dp = [&](int l, int r, int flag) -> ll {
if (r - l <= 0) return 0;
if (mem[l][r][flag] != -1) { return mem[l][r][flag]; }
ll ans = 0;
if (flag == 0) { // format (~)
if (!check(l, '(') || !check(r, ')')) {
return mem[l][r][flag] = 0;
}
// () = count(l, r)
ans = (ans + count_case1(l, r)) % MOD;
for(int i = 0; i < 3; ++i) {
ans = (ans + dp(l + 1, r - 1, i)) % MOD;
}
// ()any = range dp
for(int i = l + 1; i <= r - 1; ++i) {
if (check(i, ')')) {
ll cnt_left = count_case1(l, i) + dp(l + 1, i - 1, 0) + dp(l + 1, i - 1, 1) + dp(l + 1, i - 1, 2);
ll cnt_right = dp(i + 1, r, 0) + dp(i + 1, r, 1);
cnt_left %= MOD;
cnt_right %= MOD;
ans = (ans + cnt_left * cnt_right % MOD) % MOD;
}
}
}
else if (flag == 1 && check(r, ')')) { // format ***()
for(int i = 1; i <= k; ++i) {
int at = l + i - 1;
if (at + 1 >= r) { break; }
if (!check(at, '*')) { break; }
if (check(at + 1, '(')) {
ans = (ans + dp(at + 1, r, 0)) % MOD;
}
}
}
else if (flag == 2 && check(l, '(')) { // format ()***
for(int i = 1; i <= k; ++i) {
int at = r - i + 1;
if (at - 1 <= l) { break; }
if (!check(at, '*')) { break; }
if (check(at - 1, ')')) {
ans = (ans + dp(l, at - 1, 0)) % MOD;
}
}
}
return mem[l][r][flag] = int(ans % MOD);
};
cout << dp(0, n - 1, 0) << endl;
return 0;
}