#523. 【美团杯2020】半前缀计数
设给定的字符串为s,我们假设 [ 1 , i ] ,[ j, k ] 组成了一个字符串a,为了避免重复计算,我们考虑这个字符串a最后出现是在什么时候。
[ 1 , i ] ,[ j, k ] 组成的字符串a当前是最后一次出现,当且仅当 s [ i + 1 ] != s [ j ]。
得到算法:枚举 i ,计算 [ i + 1 ,len] 中有多少个本质不同的子串且开头不是 s [ i + 1 ]。
这样计算完之后,发现少算了 s [ 1 , 0 ] , s [ 1 , 1 ] , s [ 1 , 2 ] …… s [ 1 , len ]
(一)后缀平衡树求解:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<vector>
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#define ll long long
#define llu unsigned ll
#define pr make_pair
#define pb push_back
//#define lc (cnt<<1)
//#define rc (cnt<<1|1)
#define int ll
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e15;
const int mod=998244353;
const double eps=1e-8;
const double alpha=0.75;
const int maxn=1000100;
const int p=131;
struct node
{
int lc,rc;
int si;
double key;
}t[maxn];
int res[maxn],rcnt,rt;
int height[maxn],ans[maxn];
llu pw[maxn],has[maxn];
char s[maxn];
int *bad=NULL;
pair<double,double>par;
int ass[26];
bool balance(int p)
{
return (double)t[t[p].lc].si<=t[p].si*alpha&&(double)t[t[p].rc].si<=t[p].si*alpha;
}
void pushup(int p)
{
t[p].si=t[t[p].lc].si+t[t[p].rc].si+1;
}
void dfs(int p)
{
if(!p) return ;
dfs(t[p].lc);
res[++rcnt]=p;
dfs(t[p].rc);
}
int build(int l,int r,double lkey,double rkey)
{
if(l>r) return 0;
int mid=(l+r)>>1;
int p=res[mid];
t[p].key=(lkey+rkey)/2;
t[p].lc=build(l,mid-1,lkey,t[p].key);
t[p].rc=build(mid+1,r,t[p].key,rkey);
pushup(p);
return p;
}
void re_build(int *p)
{
rcnt=0;
dfs(*p);
(*p)=build(1,rcnt,par.first,par.second);
}
bool cmp(int x,int y)
{
return s[x]<s[y]||(s[x]==s[y]&&t[x-1].key<t[y-1].key);
}
bool eq(int l1,int l2,int len)
{
int r1=l1+len-1,r2=l2+len-1;
return has[r1]-has[l1-1]*pw[len]==has[r2]-has[l2-1]*pw[len];
}
int getlcp(int x,int y)
{
int l=1,r=min(x,y),ans=0;//二分长度
while(l<=r)
{
int mid=(l+r)>>1;
if(eq(x-mid+1,y-mid+1,mid)) ans=mid,l=mid+1;
else r=mid-1;
}
return ans;
}
void _insert(int &p,int xid,double lkey,double rkey)
{
if(!p)
{
p=xid;
t[p].si=1;
t[p].key=(lkey+rkey)/2;
t[p].lc=t[p].rc=0;
bad=NULL;
return ;
}
if(cmp(xid,p)) _insert(t[p].lc,xid,lkey,t[p].key);
else _insert(t[p].rc,xid,t[p].key,rkey);
pushup(p);
if(!balance(p)) bad=&p,par=pr(lkey,rkey);
}
int findRank(int p,int xid)
{
if(p==xid) return t[t[p].lc].si+1;
if(cmp(xid,p)) return findRank(t[p].lc,xid);
else return t[t[p].lc].si+1+findRank(t[p].rc,xid);
}
int findId(int p,int k)
{
if(!p) return 0;
if(t[t[p].lc].si==k-1) return p;
if(t[t[p].lc].si>=k) return findId(t[p].lc,k);
else return findId(t[p].rc,k-t[t[p].lc].si-1);
}
void __insert(int &rt,int xid)
{
has[xid]=has[xid-1]*p+s[xid];
_insert(rt,xid,0,dnf);
if(bad!=NULL)
re_build(bad);
int k=findRank(rt,xid);
int pre=findId(rt,k-1);
int nt=findId(rt,k+1);
ans[xid]+=height[nt];
height[xid]=getlcp(pre,xid);
height[nt]=getlcp(xid,nt);
ans[xid]-=height[xid];
ans[xid]-=height[nt];
}
void buildTree(int n)
{
for(int i=1;i<=n;i++)
ans[i]+=ans[i-1],__insert(rt,i);
}
signed main(void)
{
scanf("%s",s+1);
int n=strlen(s+1);
reverse(s+1,s+n+1);
pw[0]=1;
for(int i=1;i<=n;i++)
pw[i]=pw[i-1]*p;
buildTree(n);
for(int i=1;i<=n;i++)
ans[i]+=(i+1)*i/2;
reverse(s+1,s+n+1);
reverse(ans+1,ans+n+1);
for(int i=1;i<=n;i++)
ass[s[i]-'a']+=ans[i]-ans[i+1];
int res=ans[1]-ass[s[1]-'a'];
for(int i=1;i<n;i++)
{
res+=ans[i+1];
ass[s[i]-'a']-=ans[i]-ans[i+1];
res-=ass[s[i+1]-'a'];
}
printf("%lld\n",res+n+1);
return 0;
}
(二)后缀自动机求解:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#define ll long long
#define llu unsigned ll
using namespace std;
const int maxn=1000100;
char str[maxn];
struct Sam
{
int last,cnt;
int nt[maxn<<1][26],fa[maxn<<1];
int len[maxn<<1],sum[maxn<<1];
int x[maxn<<1],y[maxn<<1];
ll ans[26];
void init(void)
{
last=1;
cnt=1;
fa[1]=0;
len[1]=0;
}
int num(int p)
{
return len[p]-len[fa[p]];
}
void _insert(int c)
{
int nowp=++cnt,p=last;
len[nowp]=len[last]+1;
while(p&&!nt[p][c]) nt[p][c]=nowp,p=fa[p];
if(!p) fa[nowp]=1,ans[c]+=num(nowp);
else
{
int q=nt[p][c];
if(len[q]==len[p]+1) fa[nowp]=q,ans[c]+=num(nowp);
else
{
int nowq=++cnt;
len[nowq]=len[p]+1;
memcpy(nt[nowq],nt[q],sizeof(nt[q]));
ans[c]-=num(q);
fa[nowq]=fa[q];
fa[nowp]=fa[q]=nowq;
ans[c]+=num(q)+num(nowq)+num(nowp);
while(p&&nt[p][c]==q) nt[p][c]=nowq,p=fa[p];
}
}
last=nowp;
sum[last]=1;
return ;
}
}sam;
int main(void)
{
sam.init();
scanf("%s",str);
int n=strlen(str);
ll ans=0;
for(int i=n-1;i>=0;i--)
{
sam._insert(str[i]-'a');
for(int j=0;j<26;j++)
{
if(j!=str[i]-'a')
ans+=sam.ans[j];
}
}
printf("%lld\n",ans+n+1);
return 0;
}
理解一下这种写法,新加入一个点c后,last指向nowp,新增加的以c结尾的本质不同的子串的个数就是 len(last)- len(fa(last))
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#define ll long long
#define llu unsigned ll
using namespace std;
const int maxn=1000100;
char str[maxn];
struct Sam
{
int last,cnt;
int nt[maxn<<1][26],fa[maxn<<1];
int len[maxn<<1],sum[maxn<<1];
int x[maxn<<1],y[maxn<<1];
ll ans[26];
void init(void)
{
last=1;
cnt=1;
fa[1]=0;
len[1]=0;
}
int num(int p)
{
return len[p]-len[fa[p]];
}
void _insert(int c)
{
int nowp=++cnt,p=last;
len[nowp]=len[last]+1;
while(p&&!nt[p][c]) nt[p][c]=nowp,p=fa[p];
if(!p) fa[nowp]=1;
else
{
int q=nt[p][c];
if(len[q]==len[p]+1) fa[nowp]=q;
else
{
int nowq=++cnt;
len[nowq]=len[p]+1;
memcpy(nt[nowq],nt[q],sizeof(nt[q]));
fa[nowq]=fa[q];
fa[nowp]=fa[q]=nowq;
while(p&&nt[p][c]==q) nt[p][c]=nowq,p=fa[p];
}
}
last=nowp;
sum[last]=1;
ans[c]+=num(last);
return ;
}
}sam;
int main(void)
{
sam.init();
scanf("%s",str);
int n=strlen(str);
ll ans=0;
for(int i=n-1;i>=0;i--)
{
sam._insert(str[i]-'a');
for(int j=0;j<26;j++)
{
if(j!=str[i]-'a')
ans+=sam.ans[j];
}
}
printf("%lld\n",ans+n+1);
return 0;
}