题解:
考虑设
d
p
i
dp_i
dpi表示以
i
i
i结尾的前缀的划分方案数,因为有2条件的限制,可以得到容斥式子:
d
p
i
=
∑
j
(
−
1
)
C
(
S
j
+
1
,
i
)
d
p
j
dp_i = \sum_{j}(-1)^{C(S_{j+1,i})}dp_j
dpi=j∑(−1)C(Sj+1,i)dpj
C ( S ) C(S) C(S)表示 S S S的最小循环节循环的次数,相当于从前面某个位置转移过来,中间全用相同的串划分,这样一个非法的串就会被容斥掉(考虑会被算多少次即可)。
现在只需要枚举循环串即可,具体怎么处理,可以枚举循环节长度 L L L,然后对于那些 i L iL iL的位置选出来作为关键点,求一下前后缀的lcp,这里用个后缀数组加ST表就可以 O ( 1 ) O(1) O(1)了。然后就可以得到一个三元组 ( L , l , r ) (L,l,r) (L,l,r)表示 a a a在 ( l , r ) (l,r) (l,r)之间, b b b在 ( a , r + L ) (a,r+L) (a,r+L)之间,且 L ∣ ( b − a + 1 ) L|(b-a+1) L∣(b−a+1)的串 S a , b S_{a,b} Sa,b都是循环串,且最小循环节为 L L L。 显然,若有 Z ∣ L Z|L Z∣L,且 ( l , l + L − 1 ) (l,l+L-1) (l,l+L−1)的最小循环节为 Z Z Z,那么 ( L , l , r ) (L,l,r) (L,l,r)这个三元组就不合法。具体怎么实现可以参考代码。
然后这些三元组的贡献,发现对于某个前缀,一个三元组如果包含他,则是个等差数列,可以用前缀和 O ( 1 ) O(1) O(1)算出来,然后考虑每个三元组一共对多少个前缀产生贡献,显然是 ∑ ( r − l + 1 − L ∗ 2 ) \sum(r-l+1-L*2) ∑(r−l+1−L∗2),求个和发现是求最小循环节循环次数等于2的串,根据Runs Theorem,这个的复杂度是 O ( n log n ) O(n \log n) O(nlogn)级别的(具体证明请看题解),于是我们就在 O ( n log n ) O(n \log n) O(nlogn)的时间内解决了这个问题。
#include <bits/stdc++.h>
using namespace std;
typedef pair <int,int> pii;
const int N=1e6+50, L=21, mod=998244353;
inline int add(int x,int y) {return (x+y>=mod) ? (x+y-mod) : (x+y);}
inline int dec(int x,int y) {return (x-y<0) ? (x-y+mod) : (x-y);}
inline void up(int &x,int y) {x=add(x,y);}
inline void dn(int &x,int y) {x=dec(x,y);}
int n,lg[N]; char s[N];
struct SA {
int sa[N],a[N],b[N],c[N],*rk=a,*sa2=b,m,h[L][N];
inline void Rsort() {
for(int i=1;i<=m;i++) c[i]=0;
for(int i=1;i<=n;i++) ++c[rk[i]];
for(int i=1;i<=m;i++) c[i]+=c[i-1];
for(int i=n;i>=1;i--) sa[c[rk[sa2[i]]]--]=sa2[i];
}
inline void init() {
for(int i=1;i<=n;i++) sa2[i]=i, rk[i]=s[i]-'a'+1;
m=26; Rsort();
for(int w=1;w<=n;w<<=1) {
int p=0;
for(int i=n-w+1;i<=n;i++) sa2[++p]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) sa2[++p]=sa[i]-w;
Rsort(); swap(rk,sa2); rk[sa[1]]=p=1;
for(int i=2;i<=n;i++)
rk[sa[i]]=(sa2[sa[i]]==sa2[sa[i-1]] && sa2[sa[i]+w]==sa2[sa[i-1]+w]) ? p : ++p;
m=p; p=0; if(m==n) break;
}
for(int i=1,k=0,j;i<=n;h[0][rk[i++]]=k)
for(k?k--:k,j=sa[rk[i]-1];s[i+k]==s[j+k];++k);
for(int i=1;i<L;i++)
for(int j=1;j+(1<<i)-1<=n;++j)
h[i][j]=min(h[i-1][j],h[i-1][j+(1<<(i-1))]);
}
inline int ask(int x,int y) {
x=rk[x], y=rk[y];
if(x>y) swap(x,y);
++x; int l=lg[y-x+1];
return min(h[l][x],h[l][y-(1<<l)+1]);
}
} ori,rev;
const int T=2e7+50;
int tot,cov[T],si[N];
int f[N],prd[T][2];
vector <pii> g[N];
int main() {
scanf("%s",s+1); n=strlen(s+1);
for(int i=2;i<=n;i++) lg[i]=lg[i>>1]+1;
ori.init();
reverse(s+1,s+n+1);
rev.init();
for(int i=1;i<=n;i++) si[i]=si[i-1]+n/i;
for(int i=1;i<=n;i++) {
int l=0, r=0;
for(int j=i;j+i<=n;j+=i) if(j>r) {
int L=rev.ask(n-j+1,n-(j+i)+1), R=ori.ask(j,j+i);
l=j-L+1, r=j+R-1;
if(r-l+1<i) continue;
if(cov[si[i-1]+j/i]) continue;
for(int x=i+i;l+x+x-1<=r+i;x+=i)
cov[si[x-1]+(l+x-1)/x]=1;
int last=tot+1;
for(int k=l;k<=r-i+1;k++) {
if(k<l+i) ++tot;
g[k+2*i-1].push_back(pii(i,(k-l)%i+last));
}
}
}
int ans; f[0]=ans=1;
for(int i=1;i<=n;i++) {
f[i]=ans;
for(auto v:g[i]) {
int x=v.second, l=v.first;
up(prd[x][(i/l)&1],f[i-2*l]);
dn(f[i],prd[x][(i/l)&1]*2%mod);
} ans=add(ans,f[i]);
} cout<<f[n]<<'\n';
}