题目大意
给定一个长度为 n n n 的、仅含 a,b 的字符串 s s s,每次可以对 s s s 做下列操作:
- 在任意位置添加或删除 aa;
- 在任意位置添加或删除 bbb;
- 在任意位置添加或删除 ababab。
问 s s s 能变成多少种长度为 x x x 的字符串,答案模 998244353 998244353 998244353。
n ≤ 3 × 1 0 5 , x ≤ 1 0 9 n \le 3 \times 10^5,\ x \le 10^9 n≤3×105, x≤109
\\
\\
\\
题解
官方题解直接抛出了个 Tetrahedral symmetry。。。这玩意到底是怎么跟四面体旋转扯上关系的啊
以下是比较算法竞赛的思路。
就是每个字符串的最小表示竟然是唯一的。具体来说,给定一个字符串,它能唯一转化成一个长度在 4 以内的串。
而这种基础串只有 12 种。
证明的话因为题解啥也没写,不知道这鬼东西是不是有什么绝妙的证明,所以我只能脑补一个马后炮的归纳证法。首先手撸长度为 5 以内的所有串,发现它们确实能且仅能转化成 12 种基本串,然后归纳法,假设串长为 n n n 时结论成立,串长为 n + 1 n+1 n+1 时可以先对前 n n n 位做操作使其转化成 4 以内的基本串,然后再加上第 n + 1 n+1 n+1 位的字符形成一个长度为 5 的串,由手撸结果,它能转化为基本串。
这个归纳证明也给出了求一个串转化成基本串(即最小表示)的方法:从左往右一位一位地添加字符,若当前字符串能转化成基本串,则转化。只需一开始把手撸的长度在 5 以内的转化规则全部记在 map 里即可。
最后只需一个矩阵快速幂求出各种基本串在长度为 x x x 时有多少对应的字符串即可。具体来说,做一个转移矩阵 C C C, C i j C_{ij} Cij 表示基本串 i i i 在末尾添加一个字符(a 或 b)转化成基本串 j j j 的方案数。
代码
#include<bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;i++)
using namespace std;
typedef long long LL;
const int maxtot=13;
const LL mo=998244353;
int n,x;
string s;
unordered_map<string,int> M;
unordered_map<string,string> trans;
int tot;
void add(string s) {M[s]=++tot;}
int getMinStr(string s)
{
string t;
for(auto c:s)
{
t+=c;
if (trans.count(t)) t=trans[t];
}
return M[t];
}
struct ARR{
LL n[maxtot][maxtot];
} re;
ARR operator * (const ARR &a,const ARR &b)
{
fo(i,1,tot)
fo(j,1,tot)
{
re.n[i][j]=0;
fo(k,1,tot) (re.n[i][j]+=a.n[i][k]*b.n[k][j])%=mo;
}
return re;
}
ARR C,Ans;
void Pow(ARR x,int y)
{
for(; y; y>>=1, x=x*x) if (y&1) Ans=Ans*x;
}
void Pre()
{
add("");
add("a");
add("ab");
add("aba");
add("abb");
add("b");
add("ba");
add("bab");
add("babb");
add("bb");
add("bba");
add("bbab");
trans["aa"]=trans["bbb"]="";
trans["abaa"]="ab";
trans["abab"]="bba";
trans["abba"]="bab";
trans["abbb"]="a";
trans["baa"]="b";
trans["baba"]="abb";
trans["babba"]="bbab";
trans["babbb"]="ba";
trans["bbaa"]="bb";
trans["bbaba"]="babb";
trans["bbabb"]="aba";
for(auto p:M)
{
C.n[p.second][getMinStr(p.first+'a')]++;
C.n[p.second][getMinStr(p.first+'b')]++;
}
}
int main()
{
Pre();
cin >> n >> s >> x;
Ans.n[1][1]=1;
Pow(C,x);
printf("%lld\n",Ans.n[1][getMinStr(s)]);
}