题目
给一个仅由小写字母组成的串S(1<=|S|<=100),
统计本质不同的串T的个数,
满足串T+串T(+号为字符串拼接)是串S的子序列
换言之,TT是S的子序列
答案对998244353取模
思路来源
洛谷题解
ABC299F Square Subsequence - 蒟蒻的狂欢会 - 洛谷博客
题解
图片摘自ABC299F Square Subsequence - 蒟蒻的狂欢会 - 洛谷博客
序列问题,如果不能直接dp得出,首先考虑序列自动机
nex[i][j]预处理出i后面(不包含i本身)第一个j字符的位置
考虑枚举分界线r,将s分成左右两部分[1,r][r+1,n],
假设已经有一个合法T串,在左边最后一个位置是i,在右边最后一个位置是j,
则可以在[i,r]、[j+1,n]中再找到一个公共字母x,
拼接在合法T串的后面,视为dp的增量贡献
但是,如果任取合法公共字母x,会导致字符串计数重复,
所以,只取i、j后面第一个出现的字母x,
即nex[i][x]、nex[j][x],x遍历26个字母即可
初始情况,只需依次考虑26个字母,如考虑到x,当前分割线是r,
则置dp[x在[1,r]里第一个出现的位置][x在[r+1,n]里第一个出现的位置]为1
对于每个分界线r,都要求最后r必取,
这样,由于还限制了第一个串是贪心取的,
假设某个子序列在[1,r]中第一次出现,
且在[r+1,n]也出现,则只会在[1,r]中计数一次
复杂度O(n^3),100*100*100*26≈2e7
代码
代码中,nex[i][j]先求出>=i的第一个字母j的位置,
再用nex[i][j]=nex[i+1][j],求得>i的第一个字母j的位置
#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
typedef long long ll;
typedef double db;
typedef pair<int,int> P;
#define fi first
#define se second
#define pb push_back
#define dbg(x) cerr<<(#x)<<":"<<x<<" ";
#define dbg2(x) cerr<<(#x)<<":"<<x<<endl;
#define SZ(a) (int)(a.size())
#define sci(a) scanf("%d",&(a))
#define pt(a) printf("%d",a);
#define pte(a) printf("%d\n",a)
#define ptlle(a) printf("%lld\n",a)
#define debug(...) fprintf(stderr, __VA_ARGS__)
std::mt19937_64 gen(std::chrono::system_clock::now().time_since_epoch().count());
ll get(ll l, ll r) { std::uniform_int_distribution<ll> dist(l, r); return dist(gen); }
const int N=110,M=26,mod=998244353;
int n,nex[N][M],dp[N][N],ans;
char s[N];
void add(int &x,int y){
x=(x+y)%mod;
}
int main(){
scanf("%s",s+1);
n=strlen(s+1);
rep(i,0,M-1)nex[n+1][i]=n+1;
per(i,n,0){
rep(j,0,M-1){
nex[i][j]=nex[i+1][j];
}
nex[i][s[i]-'a']=i;
}
rep(i,0,n){
rep(j,0,M-1){
nex[i][j]=nex[i+1][j];
//printf("i:%d j:%d nex:%d\n",i,j,nex[i][j]);
}
}
rep(r,1,n){
memset(dp,0,sizeof dp);
rep(k,0,M-1){
if(nex[0][k]>r || nex[r][k]>n)continue;
dp[nex[0][k]][nex[r][k]]=1;
//printf("r:%d k:%d fi:%d se:%d\n",r,k,nex[0][k],nex[r][k]);
}
rep(i,1,r){
rep(j,r+1,n){
if(!dp[i][j])continue;
rep(k,0,M-1){
//printf("r:%d i:%d j:%d k:%d nexik:%d nexjk:%d\n",r,i,j,k,nex[i][k],nex[j][k]);
if(nex[i][k]>r || nex[j][k]>n)continue;
add(dp[nex[i][k]][nex[j][k]],dp[i][j]);
}
}
}
rep(j,r+1,n){
//printf("r:%d j:%d dp:%d\n",r,j,dp[r][j]);
add(ans,dp[r][j]);
}
}
pte(ans);
return 0;
}
/*
8
3 7 4 7 3 3 8 2
*/