题意:
给你一个长度为n且只包含数字的字符串S,问你有多少个二元组字符串(A,B)使得A和B都为S的子串并且B恰好是A+1(比如B是1,A是0,B是19,A是18等等)。
思路:
仔细思考一下什么情况下A+1==B,发现首先需要一个公共的前缀(我们下面称它为x),当不产生进位的时候,显而易见满足要求的字符串的形式为:xp==x(p+1),例如21+1==22,3+1==4;当产生进位的时候该怎么办呢,由于+1进位,那么A字符串很明显需要有9999...这样的后缀,B字符串需要有0000...这样的后缀,因此我们列出满足条件的形式:xp9999...+1==x(p+1)0000...&&0的个数==9的个数。这时候我们会发现,不产生进位的情况就是产生进位情况时0/9的个数为0时的特殊情况,因此我们合并这两种情况,最终列出需要满足的式子:xp9999...+1==x(p+1)0000...&&0的个数==9的个数。
那么如何求得答案呢?我们可以想到一种做法就是枚举前缀x,然后看它在S中有多少个满足xp9999...+1==x(p+1)0000...&&0的个数==9的个数的子串。
考虑建立后缀自动机,枚举x的操作可以转换为枚举SAM中的结点,接着我们再枚举p,看它是否能同时往p和(p+1)走,如果能,我们再看它能否同时往p9和(p+1)0走,以此类推,注意边走边计数即可。计数的时候我们用len[now]-len[fa[now]]算出这个结点代表的本质不同字符串的数量,接着再乘以它们出现次数,即siz。注意前缀x可以为空,因此我们需要特判一下1这个结点,因为它代表空串。
计数的代码:
for(int i=0;i<=8;i++)
{
int x=nxt[now][i],y=nxt[now][i+1];
ll l=(len[now]-len[fa[now]]);
if(now==1)
l++;
while(x&&y)
{
ans+=l*siz[x]*siz[y];
x=nxt[x][9],y=nxt[y][0];
}
}
总的代码:
#include<bits/stdc++.h>
#define ac return 0
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define endl "\n"
string yes="Yes\n";
string no="no\n";
const int N=1e6+10;
const int MOD=1e9+7;
int last,cnt,fa[N],nxt[N][15];
ll siz[N],len[N];
vector<int>g[N];
struct SAM
{
void init()
{
last=cnt=1;
fa[1]=len[1]=0;
}
int newnode()
{
cnt++;
fa[cnt]=len[cnt]=0;
return cnt;
}
void add(char x)
{
int c=x-'0';
int p=last;
int np=newnode();
siz[np]=1;
len[np]=len[p]+1;
last=np;
while(p&&!nxt[p][c])
{
nxt[p][c]=np;
p=fa[p];
}
if(!p)
fa[np]=1;
else
{
int q=nxt[p][c];
if(len[q]==len[p]+1)
fa[np]=q;
else
{
int nq=++cnt;
len[nq]=len[p]+1;
for(int i=0;i<=9;i++)
{
nxt[nq][i]=nxt[q][i];
}
fa[nq]=fa[q];
fa[np]=fa[q]=nq;
while(nxt[p][c]==q)
{
nxt[p][c]=nq;
p=fa[p];
}
}
}
}
}sam;
void dfs1(int now)
{
for(auto u:g[now])
{
dfs1(u);
siz[now]+=siz[u];
}
}
ll ans;
void dfs2(int now)
{
for(int i=0;i<=8;i++)
{
int x=nxt[now][i],y=nxt[now][i+1];
ll l=(len[now]-len[fa[now]]);
if(now==1)
l++;
while(x&&y)
{
ans+=l*siz[x]*siz[y];
x=nxt[x][9],y=nxt[y][0];
}
}
for(auto u:g[now])
{
dfs2(u);
}
}
void solve()
{
sam.init();
int n;
cin>>n;
string s;
cin>>s;
s='#'+s;
for(int i=1;i<=n;i++)
{
sam.add(s[i]);
}
for(int i=1;i<=cnt;i++)
{
g[fa[i]].push_back(i);
}
dfs1(1);
dfs2(1);
cout<<ans;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int t=1;
// cin>>t;
while(t--)
{
solve();
}
ac;
}