E - Chaotic Merge
题目描述
摘自 洛谷的翻译
给定两个仅由小写字母组成的字符串 x 和 y。
如果一个序列仅包含 ∣x∣ 个 0 和 ∣y∣ 个 1,则称这个序列为合并序列。
字符串 z 初始为空,按如下规则由合并序列 a 生成:
- 如果 a i = 0 a_i=0 ai=0 ,则把 x 开头的一个字符加到 z 的末尾;
- 如果 a i = 1 a_i=1 ai=1 ,则把 y 开头的一个字符加到 z 的末尾。
两个合并序列 a 和 b 被认为是不同的,如果存在某个 i,使得 a i ≠ b i a_i\neq b_i ai=bi 。
若一个字符串任意两个相邻位置上的字符都不同,则我们称该字符串是混乱的。
定义 f ( l 1 , r 1 , l 2 , r 2 ) f(l_1,r_1,l_2,r_2) f(l1,r1,l2,r2) 表示能从 x 的子串 x [ l 1 , r 1 ] x[l_1,r_1] x[l1,r1] 和 y 的子串 y [ l 2 , r 2 ] y[l_2,r_2] y[l2,r2] 生成混乱的字符串的不同的合并序列的数量,要求子串非空。
求 ∑ 1 ≤ l 1 ≤ r 1 ≤ ∣ x ∣ , 1 ≤ l 2 ≤ r 2 ≤ ∣ y ∣ f ( l 1 , r 1 , l 2 , r 2 ) \sum \limits_{1 \le l_1 \le r_1 \le |x| , 1 \le l_2 \le r_2 \le |y|} f(l_1, r_1, l_2, r_2) 1≤l1≤r1≤∣x∣,1≤l2≤r2≤∣y∣∑f(l1,r1,l2,r2) ,答案对 998244353 取模。
数据范围与提示
1 ≤ ∣ x ∣ , ∣ y ∣ ≤ 1000 1\le |x|,|y|\le 1000 1≤∣x∣,∣y∣≤1000 。
前言
听传言听错了,以为这题挺难。不过,先A了后面一题再杀回来不该得分更高吗?
思路
容易发现混合串其实就是把两个串作为了子序列,所以套路地用插入DP计数即可。
定义 d p [ i ] [ j ] [ 0 / 1 ] dp[i][j][0/1] dp[i][j][0/1] 表示 x x x 串上区间左端点为 1 ∼ i 1\sim i 1∼i ,右端点为 i i i , y y y 串上区间左端点为 1 ∼ j 1\sim j 1∼j ,右端点为 j j j ,混合串的最右端为 0 : x i 0:x_i 0:xi 或 1 : y j 1:y_j 1:yj 的混合串个数。由于插入时要求保证相邻两字符不等,判断只与最后一个字符有关。
转移过程很简单,只需要先把单个 x i x_i xi 或 y j y_j yj 作为混合串的情况算上,然后分别判断 x i ≠ y j x_i\neq y_j xi=yj 、 x i ≠ x i − 1 x_i\neq x_{i-1} xi=xi−1 、 y j ≠ y j − 1 y_j\neq y_{j-1} yj=yj−1 再加上前面的贡献即可。
然而只是这样的话我们会漏掉一个条件,就是区间不能为空。为此,有两种解决办法:
- 把DP再补上两个状态: d p [ i ] [ j ] [ 0 / 1 ] [ f x ] [ f y ] dp[i][j][0/1][fx][fy] dp[i][j][0/1][fx][fy] , f x , f y ∈ { 0 , 1 } fx,fy\in\{0,1\} fx,fy∈{0,1} 分别表示混合串是否包含 x i x_i xi 和 y j y_j yj ,然后转移稍微复杂点。
- 容斥,利用 d p [ i ] [ 0 ] [ 0 ] dp[i][0][0] dp[i][0][0] 和 d p [ 0 ] [ j ] [ 1 ] dp[0][j][1] dp[0][j][1] 来单独算区间为空的情况的方案数,然后减去它。
代码
非容斥(把额外的 [ 2 ] [ 2 ] [2][2] [2][2] 压成了一个 [ 4 ] [4] [4] ):
//By XYX
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 1005
#define ENDL putchar('\n')
#define LL long long
#define DB double
#define lowbit(x) ((-x) & (x))
LL read() {
LL f = 1,x = 0;char s = getchar();
while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
return f * x;
}
const int MOD = 998244353;
int n,m,i,j,s,o,k;
char xx[MAXN],yy[MAXN];
int dp[MAXN][MAXN][4][2];
int main() {
scanf("%s",xx + 1);
n = strlen(xx + 1);
scanf("%s",yy + 1);
m = strlen(yy + 1);
dp[n+1][m+1][3][0] = dp[n+1][m+1][3][1] = 1;
int ans = 0;
for(int i = n+1;i > 0;i --) {
for(int j = m+1;j > 0;j --) {
if(i+j < n+m+2) {
for(int k = 0;k < 4;k ++) {
if(k == 3) {
dp[i][j][k][0] = dp[i][j][k][1] = 1;
}
else dp[i][j][k][0] = dp[i][j][k][1] = 0;
for(int s = 0;s < 2;s ++) {
char pr = (s ? yy[j-1]:xx[i-1]);
if(!(k & (1<<s))) {
pr = 0;
if(k || s) continue;
}
if(i <= n && xx[i] != pr) (dp[i][j][k][s] += dp[i+1][j][k|1][0]) %= MOD;
if(j <= m && yy[j] != pr) (dp[i][j][k][s] += dp[i][j+1][k|2][1]) %= MOD;
}
}
}
// printf("dp[%d][%d] = %d\n",i,j,dp[i][j][0][0]);
if(i <= n && j <= m) {
(ans += dp[i][j][0][0]) %= MOD;
}
}
}
printf("%d\n",ans);
return 0;
}//
容斥:
#include<cstdio>//JZM yyds!!
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<stack>
#include<ctime>
#include<set>
#define ll long long
#define MAXN 1005
#define uns unsigned
#define INF 1e17
#define MOD 998244353ll
#define lowbit(x) ((x)&(-(x)))
using namespace std;
inline ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+s-'0',s=getchar();
return f?x:-x;
}
char a[MAXN],b[MAXN];
int n,m;
ll dp[MAXN][MAXN][2],ans;
inline void add(ll&a,ll b){a=(a+b+MOD)%MOD;}
signed main()
{
scanf("%s%s",a+1,b+1);
n=strlen(a+1),m=strlen(b+1);
for(int i=1;i<=n;i++){
dp[i][0][0]=1;
if(a[i]!=a[i-1])add(dp[i][0][0],dp[i-1][0][0]);
}
for(int i=1;i<=m;i++){
dp[0][i][1]=1;
if(b[i]!=b[i-1])add(dp[0][i][1],dp[0][i-1][1]);
}
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++){
dp[i][j][0]=dp[i][j][1]=1;
if(a[i]!=a[i-1])add(dp[i][j][0],dp[i-1][j][0]);
if(a[i]!=b[j])add(dp[i][j][0],dp[i-1][j][1]);
if(b[j]!=b[j-1])add(dp[i][j][1],dp[i][j-1][1]);
if(b[j]!=a[i])add(dp[i][j][1],dp[i][j-1][0]);
add(ans,dp[i][j][0]+dp[i][j][1]-dp[i][0][0]-dp[0][j][1]);
}
printf("%lld\n",(ans+MOD)%MOD);
return 0;
}