题目看起来简单但比较难,一开始想了一会就想出一种比较类似的做法,但是正解一直没有想出来。后来看了标程和题解,想了好几天才明白。
题意:有一个只由abc构成的字符串,可以进行2种操作,把相邻的两个字母前一个换成后面的或把后面的换成前面的。比如ab操作一次可以变成aa或bb。一个字符串称作“平衡的”指它里面出现abc的个数两两只差不超过1。问经过任意次操作后得到的字符串有多少个是“平衡的”。
思路:对于字符串s,把s的相邻的相同的字符都缩成一个,比如“aaab”缩成“ab”。如果b缩完后的字符串是a缩完后的子序列,那么b能通过一系列操作变成a。然后就可以dp,dp[i][j][k][l]表示s前i个字符构成的有j个a,k个b,l个c的子序列的个数。因为abc个数只差不超过1,所以abc的个数应该都是n/3左右的数,所以先计算出cnt数组,cnt[i][j][k]表示压缩后的字符串中有ijk个abc,能构成多少种不同的“平衡的”原串。nxt[i][j]表示s的第i个字符起字母j第一次出现的位置。dp[i][j][k][l]可以转移到dp[nxt[i+1][0]][j+1][k][l],dp[nxt[i+1][1]][j][k+1][l],dp[nxt[i+1][2]][j][k][l+1],一边dp一边统计答案。
#include<cstdio>
#include<string>
#include<cstring>
#include<utility>
#include<cmath>
#include<map>
#include<queue>
#include<set>
#include<algorithm>
#include<vector>
#include<iostream>
#define ll long long
#define pii pair<int,int>
#define mp make_pair
#define fi first
#define se second
#define inf 0x7fffffff
#define minn(x,y) x=min(x,y)
#define maxx(x,y) x=max(x,y)
using namespace std;
string s;
int mod=51123987;
int dp[160][53][53][53];
int cnt[55][55][55],c[55][55];
int ans;
int nxt[160][3];
void add(int x)
{
ans+=x;
ans%=mod;
}
int get(int i,int j,int k,int x,int y,int z)
{
if(i>x||j>y||k>z)
{
return 0;
}
if((!i&&x)||(!j&&y)||(!k&&z))
{
return 0;
}
if(!i||!j||!z)
{
return 1;
}
return 1ll*c[x][i]*c[y][j]%mod*c[z][k]%mod;
}
int main()
{
int i,j,k,n,m,x,y,z,l;
scanf("%d",&n);
cin>>s;
c[1][1]=1;
for(i=2;i<53;i++)
{
for(j=1;j<53;j++)
{
c[i][j]=(c[i-1][j]+c[i-1][j-1])%mod;
}
}
y=z=x=n/3;
for(i=0;i<53;i++)
{
for(j=0;j<53;j++)
{
for(k=0;k<53;k++)
{
if(n%3==0)
{
cnt[i][j][k]=get(i,j,k,x,y,z);
}
else if(n%3==1)
{
cnt[i][j][k]=(1ll*get(i,j,k,x+1,y,z)+get(i,j,k,x,y+1,z)+get(i,j,k,x,y,z+1))%mod;
}
else
{
cnt[i][j][k]=(1ll*get(i,j,k,x,y+1,z+1)+get(i,j,k,x+1,y,z+1)+get(i,j,k,x+1,y+1,z))%mod;
}
}
}
}
nxt[n-1][0]=nxt[n-1][1]=nxt[n-1][2]=n+1;
nxt[n-1][s[n-1]-'a']=n-1;
for(i=n-2;i>=0;i--)
{
for(j=0;j<3;j++)
{
nxt[i][j]=nxt[i+1][j];
}
nxt[i][s[i]-'a']=i;
}
x+=2;
dp[nxt[0][0]][1][0][0]++;
dp[nxt[0][1]][0][1][0]++;
dp[nxt[0][2]][0][0][1]++;
for(i=0;i<n;i++)
{
for(j=0;j<x;j++)
{
for(k=0;k<x;k++)
{
if(j+k>i+1)
{
break;
}
for(l=0;l<x;l++)
{
if(j+k+l>i+1)
{
break;
}
add(1ll*cnt[j][k][l]*dp[i][j][k][l]%mod);
if(s[i]!='a')
{
dp[nxt[i+1][0]][j+1][k][l]+=dp[i][j][k][l];
dp[nxt[i+1][0]][j+1][k][l]%=mod;
}
if(s[i]!='b')
{
dp[nxt[i+1][1]][j][k+1][l]+=dp[i][j][k][l];
dp[nxt[i+1][1]][j][k+1][l]%=mod;
}
if(s[i]!='c')
{
dp[nxt[i+1][2]][j][k][l+1]+=dp[i][j][k][l];
dp[nxt[i+1][2]][j][k][l+1]%=mod;
}
}
}
}
}
printf("%d",ans);
return 0;
}