dp[i][j]表示区间[i,j]的答案。
如果区间长度为len,那么这个区间的答案一定是经过len-1次操作得到的。
所以我们从长度为len的区间转移向长度为len+1的区间时,需要枚举一个符号,这个符号就是第len次操作使用的符号。每个符号求到的答案加起来就是这个区间的答案。
假设现在我们关注符号p。符号左边有a种情况,第i种为A[i],符号右边有b种情况,第j种为B[j]。
如果符号p是加号或者减号,那么我们要的答案就是∑∑(A[i]±B[j])=b∑A[i]±a∑B[j]。其中∑A[i]就是左边区间的答案,∑B[j]就是右边区间的答案,所以我们只要计算出a,b就可以转移了。如果一个区间的长度是len,即len-1个符号,那么这个区间的方案总数就是A(len-1,len-1),即(len-1)!。预处理出所有阶乘然后就可以O(1)计算a,b了。
还需要考虑的一点是,因为第len个使用的符号是p,那么p左边的符号和p右边的符号之前都是隔离开的,而且必须要有一个顺序。由于这两个区间的答案分别都已经包含了自己的顺序,但是两个区间之间的顺序却没有被计算过,因此答案还要再乘以c。c可以理解为len+1个两种元素的全排列,也可以理解为在len+1和位子中挑一些给左边,剩下的给右边。两种思路都可以计算出c。
如果符号p是乘号,那么我们要的答案就是∑∑(A[i]*B[j])=∑A[i]*∑B[j]。即两边区间的答案直接乘起来。
这里同样需要乘以一个c。
希望自己以后可以思考清楚状态的定义,当前状态和子状态之间的关系。
比如在本题中,dp[i][j]代表区间[i,j]的答案。我们会枚举一个符号,然后连接左右连个子状态。
我们就应该想到,枚举的这个符号就是最后一次使用的符号,但是对之前的符号顺序并没有要求。
而我们在转移的过程中,使用到了子状态的结果,而子状态的结果是完全包含这个子状态的所有顺序的。
但是我们在将两个子状态结合在一起的时候,应该要考虑这两个子状态之间的顺序关系。
即要考虑清楚,本状态的顺序要求,子状态的顺序状况,本状态与子状态的顺序关系,子状态之间的顺序关系。
代码
#include<stdio.h>
#include<algorithm>
#include<string.h>
using namespace std;
const int maxn = 110;
const int mod = 1e9+7;
int n;
int a[maxn];
int dp[maxn][maxn];
char s[maxn];
int pow2[maxn];
int inv[maxn];
int get(int a,char o,int b)
{
if(o=='+') return (1ll*a+b)%mod;
else if(o=='-') return (1ll*a-b+mod)%mod;
else return 1ll*a*b%mod;
}
void read()
{
for(int i=0;i<n;i++)
scanf("%d",a+i);
scanf("%s",s);
}
int mp(int x,int n)
{
int ret=1;
while(n)
{
if(n&1) ret=1ll*ret*x%mod;
x=1ll*x*x%mod;
n>>=1;
}
return ret;
}
void solve()
{
read();
memset(dp,0,sizeof(dp));
for(int i=0;i<n;i++)
dp[i][i]=a[i];
for(int len=2;len<=n;len++)
for(int l=0;l<n-len+1;l++)
{
int r = l+len-1;
for(int p=l;p<r;p++)
if(s[p]=='*') dp[l][r]=(1ll*dp[l][r]+1ll*pow2[len-2]*inv[p-l]%mod*inv[r-p-1]%mod*get(dp[l][p],s[p],dp[p+1][r])%mod)%mod;
else dp[l][r]=(1ll*dp[l][r]+1ll*pow2[len-2]*inv[p-l]%mod*inv[r-p-1]%mod*get(1ll*pow2[r-p-1]*dp[l][p]%mod,s[p],1ll*pow2[p-l]*dp[p+1][r]%mod)%mod)%mod;
//printf("%d %d %d\n",l,r,dp[l][r]);
}
printf("%d\n",dp[0][n-1]);
}
void init()
{
pow2[0]=1;
pow2[1]=1;
inv[0]=1;
inv[1]=1;
for(int i=2;i<maxn;i++)
{
pow2[i]=1ll*pow2[i-1]*i%mod;
inv[i]=mp(pow2[i],mod-2);
}
}
int main()
{
init();
while(~scanf("%d",&n)) solve();
return 0;
}