题意
对于S,T两个字符串,构造一个(|S|+|T|)长度的序列 a a a,按以下规则操作:
- 若 a i = 0 a_i=0 ai=0,将当前S的首字母删除,添加到z的末尾
- 若 a i = 1 a_i=1 ai=1,将当前T的首字母删除,添加到z的末尾
若得到的z字符串相邻字母互不相同,则
a
a
a为一个合法序列。
求S任意一个非空连续子串与T任意一个非空连续子串能得到的合法
a
a
a序列数量之和。
题解
对于一个固定的S与T字符串,利用动态规划很容易列出转移方程。设
d
p
[
i
]
[
j
]
[
k
]
dp[i][j][k]
dp[i][j][k]表示当前S已经放入
1
1
1到
i
i
i位,T已经放入第
1
1
1到
j
j
j位,
k
=
0
o
r
1
k=0\ or\ 1
k=0 or 1,
0
0
0表示末尾存放的是S的第i个字符,
1
1
1表示末尾存放的是T的第j个字符,转移方程由题意直接写即可。
然后我就做不下去了…
其实将此DP扩展到任意一个连续子串很简单。设
d
p
[
i
]
[
j
]
[
k
]
dp[i][j][k]
dp[i][j][k]表示当前S处理到第
i
i
i位,T处理到第
j
j
j位,但这里S、T的起始包括了前面任意位置,
k
=
0
o
r
1
k=0\ or\ 1
k=0 or 1,
0
0
0表示末尾存放的是S的第i个字符,
1
1
1表示末尾存放的是T的第j个字符。转移方程式具体见代码。
注意题目要求是S,T非空子串,当
i
>
0
,
j
>
0
i>0,j>0
i>0,j>0时,
d
p
[
i
]
[
j
]
[
k
]
dp[i][j][k]
dp[i][j][k]始终保存的是非空子串的状态。但转移时字符加入一个空串,故另外预处理
d
p
[
i
]
[
0
]
[
0
]
dp[i][0][0]
dp[i][0][0]表示当前只有
s
s
s,最后一位是
s
i
s_i
si的全部状态数。同理,也需要预处理
d
p
[
0
]
[
j
]
[
1
]
dp[0][j][1]
dp[0][j][1]。
时间复杂度为
O
(
∣
S
∣
∣
T
∣
)
O(|S||T|)
O(∣S∣∣T∣)。
哇,感觉没写清楚,具体还是见代码吧。
/*************************************************************************
> File Name: 1.cpp
> Author: Knowledge-Pig
> Mail: 925538513@qq.com
> Blog: https://www.cnblogs.com/Knowledge-Pig/
> Created Time: 2021年03月27日 星期六 14时51分42秒
************************************************************************/
#include<bits/stdc++.h>
using namespace std;
#define For(i,a,b) for(int i=(a);i<=(b);++i)
#define LL long long
#define pb push_back
#define fi first
#define se second
#define pr pair<int,int>
#define mk(a,b) make_pair(a,b)
int read(){
char x=getchar(); int u=0,fg=0;
while(!isdigit(x)){ if(x=='-') fg=1; x=getchar(); }
while(isdigit(x)){ u=(u<<3)+(u<<1)+(x^48); x=getchar(); }
return fg?-u:u;
}
const int mod=998244353;
char s[2021],t[2021];
int n,m,dp[1005][1005][2],ans=0;
void add(int &x,int y){
x+=y;
if(x>=mod) x-=mod;
if(x<0) x+=mod;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("input.in", "r", stdin);
freopen("output.out", "w", stdout);
#endif
scanf("%s%s",s+1,t+1);
n=strlen(s+1); m=strlen(t+1);
int cut=0;
For(i,1,n){
if(s[i]==s[i-1]) cut=i-1;
dp[i][0][0]=i-cut;
}
cut=0;
For(j,1,m){
if(t[j]==t[j-1]) cut=j-1;
dp[0][j][1]=j-cut;
}
For(i,1,n) For(j,1,m){
if(i>1 && s[i]!=s[i-1]) add(dp[i][j][0],dp[i-1][j][0]);
if(i>1 && s[i]!=t[j]) add(dp[i][j][0],dp[i-1][j][1]);
if(j>1 && t[j]!=t[j-1]) add(dp[i][j][1],dp[i][j-1][1]);
if(j>1 && t[j]!=s[i]) add(dp[i][j][1],dp[i][j-1][0]);
if(s[i]!=t[j]){
add(dp[i][j][0],dp[0][j][1]);
add(dp[i][j][1],dp[i][0][0]);
}
add(ans,dp[i][j][0]);
add(ans,dp[i][j][1]);
// printf("%d-%d : %d %d\n",i,j,dp[i][j][0],dp[i][j][1]);
}
cout<<ans<<endl;
return 0;
}