题目
Description
Input
Output
一行一个整数表示答案。
Sample Input
样例输入:
3 3
Sample Output
样例输出:
33
Data Constraint
思路
首先我们忽略重复的字符串,定义 ƒ(n)表示长度为 n 的回文串,或由两个回文串拼成的字符串数量。那么可以通过枚举第一个回文串的长度(可以为 0)可以算出f(n)
但是正如刚才所说,会对如 abaaba 这样的字符串重复计算,即会被认为是一个完整的字符串,又会认为时两个 aba 拼接而成。考虑有哪些字符串会被重复计算。
不难发现,当一个回文串是某个字符串重复多次构成时,那么这个字符串就会被计算多 次,如 abaabaaba 这个字符串会由 abaabaaba、aba 和 abaaba、abaaba 和 aba 这3 种情况计算 3 次。
进而可以发现,重复了 k 次的字符串就会被重复计算 k 次,那么定义 g(n)表示长度为n 的回文串,或者只有一种分割方案分割为两个回文串的字符串
那么最后答案就可以通过 g 数组得出
代码
#include<bits/stdc++.h>
#define N 200077
#define mod 998244353
#define ll long long
using namespace std;
ll p[N],g[N],f[N],yjy,n;
int main()
{
freopen("string.in","r",stdin);
freopen("string.out","w",stdout);
scanf("%lld%lld",&n,&p[1]);
p[0]=1,f[1]=g[1]=p[1],(yjy+=p[1]*n)%=mod;
for(int i=2; i<=n; i++) (g[i]+=mod-g[1]*i)%=mod;
for(int i=2; i<=n; i++)
{
p[i]=p[i-1]*p[1]%mod,f[i]=(f[i-2]*p[1]+p[i/2]*p[1]+p[(i+1)/2])%mod,(g[i]+=mod+f[i])%=mod;
for(int j=i*2; j<=n; j+=i) (g[j]+=mod-g[i]*(j/i))%=mod;
(yjy+=g[i]*(n/i))%=mod;
}
printf("%lld",yjy);
}