C F 613 E P u z z l e L o v e r \mathrm{CF613E\ Puzzle \ Lover} CF613E Puzzle Lover
- 头铁地去做一道 ∗ 3200 *3200 ∗3200的题,没想到还是比较好想~~(看题解~~
题目意思
-
就是给你一个 2 × n 2\times n 2×n的字符矩阵,以及给你个 m m m的字符串 t t t。问你有多少种方案使得路径上的字符连接起来的字符串为 t t t(每次可以往上下左右走,但不能走到同一格)。答案对 1 e 9 + 7 1e9+7 1e9+7取模
-
n , m ≤ 2 e 3 n,m\leq 2e3 n,m≤2e3
S o l \mathrm{Sol} Sol
-
前置知识:字符串哈希+DP,是不是 P J PJ PJ内容啊。
-
我们设这条路径的端点为 s , t s,t s,t。那么一条 s → t s\to t s→t的路径样子如下:
-
部分一: 从 s s s往左走,再往右绕回来
-
部分二: 上下往右边移动
-
部分三: 绕道 t t t的右边再绕回来
-
然后我们对于 1 , 3 1,3 1,3部分字符串哈希。这个很简单,即枚举终点用哈希判断是否合法。以及为中间部分做 d p dp dp预处理。不在过多讲。
-
关键是中间反复横跳的过程,设 f i , j , k f_{i,j,k} fi,j,k表示到 ( i , j ) (i,j) (i,j)配对到 t t t中第 k k k个字符,从 f i , j − 1 , k − 1 f_{i,j-1,k-1} fi,j−1,k−1转移过来。同理 g i , j , k g_{i,j,k} gi,j,k表示到 ( i , j ) (i,j) (i,j)配对到 t t t中第 k k k个字符,从 g 3 − i , j , k − 1 g_{3-i,j,k-1} g3−i,j,k−1转移过来。( 3 − i 3-i 3−i即在第一行与第二行反复横跳的简单转换)。
-
那么转移就很简单(不会往回走)。所以 f , g f,g f,g之间相互更新即可。具体的:
- f i , j + 1 , k + 1 = ∑ g i , j , k [ j ≤ n , c h i , j + 1 = t k + 1 ] f_{i,j+1,k+1}=\sum g_{i,j,k}[j\leq n,ch_{i,j+1}=t_{k+1}] fi,j+1,k+1=∑gi,j,k[j≤n,chi,j+1=tk+1]
- f i , j + 1 , k + 1 = ∑ f i , j , k [ j ≤ n , c h i , j + 1 = t k + 1 ] f_{i,j+1,k+1}=\sum f_{i,j,k}[j\leq n,ch_{i,j+1}=t_{k+1}] fi,j+1,k+1=∑fi,j,k[j≤n,chi,j+1=tk+1]
- g 3 − i , j , k + 1 = ∑ f i , j , k [ c h 3 − i , j = t k + 1 ] g_{3-i,j,k+1}=\sum f_{i,j,k}[ch_{3-i,j}=t_{k+1}] g3−i,j,k+1=∑fi,j,k[ch3−i,j=tk+1]
-
对于第三部分我们来统计答案。我们枚举终点算出起点,那么 a n s = ∑ f i , j − 1 , m − 1 + ∑ g i , j − 1 , m − 1 ans=\sum f_{i,j-1,m-1}+\sum g_{i,j-1,m-1} ans=∑fi,j−1,m−1+∑gi,j−1,m−1(为什么是 j − 1 j-1 j−1因为第二部分的结束的那一列加一也是 t t t所在的那一列)。
-
我们还存在 s , t s,t s,t在同一列的情况,此时把字符串翻转的答案是重的所以要删掉。同时也要加上这种情况的贡献(即代码中的 g s gs gs,此时不存在第二部分的贡献)
-
于是我们做出了一道 3200 3200 3200的题。
C o d e \mathrm{Code} Code
#include <bits/stdc++.h>
#define pb push_back
#define int long long
using namespace std;
inline int read()
{
int sum=0,ff=1; char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-') ff=-1;
ch=getchar();
}
while(isdigit(ch))
sum=sum*10+(ch^48),ch=getchar();
return sum*ff;
}
const int N=2005;
const int mo=1e9+7;
const int base=131;
char ch[3][N],t[N];
int Hash1[3][N],Hash2[3][N],Hasht[N],bin[N];
int n,m,f[3][N][N],g[3][N][N],ans,gs;
inline int get1(int l,int r,int id)
{
return Hash1[id][r]-Hash1[id][l-1]*bin[r-l+1];
}
inline int get2(int l,int r,int id)
{
return Hash2[id][l]-Hash2[id][r+1]*bin[r-l+1];
}
inline void Add(int &x,int y)
{
x+=y;
if(x>=mo) x-=mo;
}
inline int merge(int H1,int H2,int len)
{
return H2+H1*bin[len];
}
inline void dp()
{
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
for ( int i=1;i<=m;i++ )
Hasht[i]=Hasht[i-1]*base+(t[i]-'a'+1);
for ( int i=1;i<=2;i++ )
for ( int j=1;j<=n;j++ )
for ( int k=j;k>=1;k-- )
{
int len=2*(j-k+1);
if(len>m) continue;
int HA=get1(k,j,i);
int HB=get2(k,j,3-i);
int HH=Hasht[len];
if(merge(HB,HA,j-k+1)==HH)
{
if(len==m) gs++;
else g[i][j][len]=1;
}
}
for ( int i=1;i<=2;i++ )
for ( int j=1;j<=n;j++ )
if(ch[i][j]==t[1]) g[i][j][1]=1;
for ( int i=m;i>=1;i-- )
Hasht[i]=Hasht[i+1]*base+(t[i]-'a'+1);
for ( int k=1;k<m;k++ )
for ( int j=1;j<=n;j++ )
for ( int i=1;i<=2;i++ )
{
if(g[i][j][k])
if(j<n&&ch[i][j+1]==t[k+1])
Add(f[i][j+1][k+1],g[i][j][k]);
if(f[i][j][k])
{
if(j<n&&ch[i][j+1]==t[k+1])
Add(f[i][j+1][k+1],f[i][j][k]);
if(ch[3-i][j]==t[k+1])
Add(g[3-i][j][k+1],f[i][j][k]);
}
}
for ( int i=1;i<=2;i++ )
for ( int j=1;j<=n;j++ )
{
if(ch[i][j]==t[m])
{
if(m==1) Add(ans,1);
Add(ans,f[i][j-1][m-1]);
Add(ans,g[i][j-1][m-1]);
}
for ( int k=j;k<=n;k++ )
{
int len=(k-j+1)*2;
if(len>m) continue;
int HA=get1(j,k,3-i);
int HB=get2(j,k,i);
int HH=Hasht[m-len+1];
if(merge(HA,HB,k-j+1)==HH)
{
if(len==m&&m!=2) gs++;
else
{
Add(ans,f[i][j-1][m-len]);
Add(ans,g[i][j-1][m-len]);
}
}
}
}
}
signed main()
{
scanf("%s",ch[1]+1);
scanf("%s",ch[2]+1);
n=strlen(ch[1]+1);
scanf("%s",t+1);
m=strlen(t+1);
bin[0]=1;
int Lim=max(n,m);
for ( int i=1;i<=Lim;i++ )
bin[i]=bin[i-1]*base;
for ( int i=1;i<=2;i++ )
for ( int j=1;j<=n;j++ )
Hash1[i][j]=Hash1[i][j-1]*base+(ch[i][j]-'a'+1);
for ( int i=1;i<=2;i++ )
for ( int j=n;j>=1;j-- )
Hash2[i][j]=Hash2[i][j+1]*base+(ch[i][j]-'a'+1);
dp();
reverse(t+1,t+m+1);
dp();
Add(ans,gs/2);
if(m==1) ans/=2;
printf("%lld\n",ans);
return 0;
}