链接
easy version
考虑给一个填好的序列(没有问号)我怎么计算深度,做法很容易想到,对于从左往右第 k k k个左括号,我就选择从右往左的第 k k k个右括号和它配对,直到两个指针相遇。
我可以算贡献,对于每个左括号,看看它在多少方案中被统计了。
f i j f_{ij} fij表示前 i i i个字符中恰好有 j j j个左括号的方案数, g i , j g_{i,j} gi,j表示第 i , i + 1 , . . . , n i,i+1,...,n i,i+1,...,n这些字符中有至少 j j j个右括号的方案数
那么答案就是
∑ i = 1 n f i − 1 , j − 1 × g i + 1 , j [ s i ≠ ′ ) ′ ] \sum _{i=1}^n f_{i-1,j-1} \times g_{i+1,j} [s_i\neq ')'] i=1∑nfi−1,j−1×gi+1,j[si=′)′]
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 1000010
#define maxe 1000010
#define cl(x) memset(x,0,sizeof(x))
#define rep(_,__) for(_=1;_<=(__);_++)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<<x<<endl
#define mod 998244353ll
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
ll read(ll x=0)
{
ll c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
ll f[2019][2019], g[2019][2019], ans=0;
char s[maxn];
int main()
{
scanf("%s",s+1);
ll n=strlen(s+1), i, j;
f[0][0]=1;
for(i=1;i<=n;i++)
{
for(j=0;j<=n;j++)
{
if(s[i]=='(' and j)
{
f[i][j]+=f[i-1][j-1];
}
if(s[i]==')')
{
f[i][j]+=f[i-1][j];
}
if(s[i]=='?')
{
f[i][j]+=f[i-1][j];
if(j)f[i][j]+=f[i-1][j-1];
}
f[i][j]%=mod;
}
}
g[n+1][0]=1;
for(i=n;i;i--)
{
for(j=0;j<=n;j++)
{
if(s[i]==')')
{
g[i][j]+=g[i+1][max(0ll,j-1)];
}
if(s[i]=='(')
{
g[i][j]+=g[i+1][j];
}
if(s[i]=='?')
{
g[i][j]+=g[i+1][j];
g[i][j]+=g[i+1][max(0ll,j-1)];
}
g[i][j]%=mod;
}
}
for(i=1;i<=n;i++)
{
if(s[i]=='(' or s[i]=='?')
{
for(j=1;j<=n;j++)(ans+=f[i-1][j-1]*g[i+1][j])%=mod;
}
}
cout<<(ans+mod)%mod;
return 0;
}
Hard Version
对于一个 i i i,满足 s i ≠ ′ ) ′ s_i \neq ')' si=′)′
设这个位置左边有 a a a个左括号, b b b个问号;右边有 c c c个右括号, d d d个问号
那么这个位置对答案的贡献是
∑ i = 0 b ∑ j = 0 d ( b i ) ( d j ) [ a + i < c + j ] = ∑ i = 0 b ∑ j = 0 d ( b i ) ( d d − j ) [ i + j < c − a + d ] = ∑ i = 0 b ∑ j = 0 d ( b i ) ( d j ) [ i + j < c − a + d ] \sum_{i=0}^b\sum_{j=0}^d \binom b i \binom d j [a+i<c+j] \\ = \sum_{i=0}^b\sum_{j=0}^d \binom b i \binom d {d-j} [i+j<c-a+d] \\ = \sum_{i=0}^b\sum_{j=0}^d \binom b i \binom d {j} [i+j<c-a+d] i=0∑bj=0∑d(ib)(jd)[a+i<c+j]=i=0∑bj=0∑d(ib)(d−jd)[i+j<c−a+d]=i=0∑bj=0∑d(ib)(jd)[i+j<c−a+d]
这个东西是 ( 1 + x ) b ( 1 + x ) d (1+x)^b(1+x)^d (1+x)b(1+x)d的前面 ( c − a + d ) (c-a+d) (c−a+d)项,也就是多项式 ( 1 + x ) b + d (1+x)^{b+d} (1+x)b+d的前面 ( c − a + d ) (c-a+d) (c−a+d)项
由于 b + d b+d b+d只有两种取值,所以多项式可以预处理,时间复杂度做到 O ( n ) O(n) O(n)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 1000010
#define maxe 1000010
#define mod 998244353ll
#define cl(x) memset(x,0,sizeof(x))
#define rep(_,__) for(_=1;_<=(__);_++)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
ll read(ll x=0)
{
ll c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
ll fact[maxn], _fact[maxn], n, ans, inv[maxn];
char s[maxn];
void prework()
{
ll i;
inv[1]=1;
for(i=2;i<maxn;i++)inv[i]=inv[mod%i]*(mod-mod/i)%mod;
fact[0]=_fact[0]=1;
rep(i,1000000)fact[i]=fact[i-1]*i%mod, _fact[i]=_fact[i-1]*inv[i]%mod;
}
ll C(ll n, ll m)
{
if(m>n)return 0;
return fact[n]*_fact[m]%mod*_fact[n-m]%mod;
}
map<ll,vector<ll>> f;
ll calc(ll a, ll b, ll c, ll d)
{
ll up=min(c-a+d-1,b+d), n=b+d, i;
if(up<0)return 0;
if(f.find(n)!=f.end())return f[n][up];
f[n].resize(n+1);
for(i=0;i<=n;i++)f[n][i]=C(n,i);
for(i=1;i<=n;i++)(f[n][i]+=f[n][i-1])%=mod;
return f[n][up];
}
ll a[maxn], b[maxn], c[maxn], d[maxn];
int main()
{
scanf("%s",s+1);
n=strlen(s+1);
prework();
ll i;
rep(i,n)
{
a[i]=a[i-1]+(s[i]=='(');
b[i]=b[i-1]+(s[i]=='?');
}
for(i=n;i;i--)
{
c[i]=c[i+1]+(s[i]==')');
d[i]=d[i+1]+(s[i]=='?');
}
ll ans=0;
rep(i,n)if(s[i]=='(' or s[i]=='?')(ans+=calc(a[i-1],b[i-1],c[i+1],d[i+1]))%=mod;
cout<<ans;
return 0;
}