绝对是我写过最长的一份代码了.
这个快敲吐了.
通过这道题能 get 到一个套路:
两颗树同时统计信息的题可以考虑在个树上跑边分治,把点扔到另一颗树的虚树上,然后跑虚树DP.
具体地,这道题中我们发现 $LCP$ 长度是反串后缀树 $LCA$ 深度,$LCS$ 是正串后缀树 $LCA$ 深度.
我们建出正反两串后缀树后,将长度大于 K1/K2 的点的深度置为 0,然后跑一个边分+虚树即可.
code:
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
#include <vector>
#include <map>
#define N 200007
#define inf 0x3f3f3f3f
#define ull unsigned long long
// 代码已写完,人已阵亡.
using namespace std;
int bug;
int K1,K2;
ull ans,W;
char S[N];
namespace IO {
void setIO(string s)
{
string in=s+".in";
string out=s+".out";
freopen(in.c_str(),"r",stdin);
// freopen(out.c_str(),"w",stdout);
}
};
struct SAM {
#define M N<<1
int tot,last;
struct Edge {
int to,w;
Edge(int to=0,int w=0):to(to),w(w){}
};
vector<Edge>G[M];
int pre[M],ch[M][26],mx[M],str_sam[M],sam_str[M],depth[M];
void Initialize() { tot=last=1; }
void extend(int c)
{
int np=++tot,p=last;
mx[np]=mx[p]+1,last=np;
for(;p&&!ch[p][c];p=pre[p]) ch[p][c]=np;
if(!p) pre[np]=1;
else
{
int q=ch[p][c];
if(mx[q]==mx[p]+1) pre[np]=q;
else
{
int nq=++tot;
mx[nq]=mx[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
pre[nq]=pre[q],pre[np]=pre[q]=nq;
for(;p&&ch[p][c]==q;p=pre[p]) ch[p][c]=nq;
}
}
}
void Build_LCP()
{
int n=strlen(S+1),i,j,p=1;
for(i=1;i<=n;++i)
{
p=ch[p][S[n-i+1]-'a'];
sam_str[p]=n-i+1;
str_sam[n-i+1]=p;
}
for(i=2;i<=tot;++i)
{
if(mx[i]>K1) depth[i]=0;
else depth[i]=mx[i];
}
for(i=2;i<=tot;++i) G[pre[i]].push_back(Edge(i,depth[i]-depth[pre[i]]));
}
void Build_LCS()
{
int n=strlen(S+1),i,j,p=1;
for(i=1;i<=n;++i)
{
p=ch[p][S[i]-'a'];
sam_str[p]=i;
str_sam[i]=p;
}
for(i=2;i<=tot;++i)
{
if(mx[i]>K2) depth[i]=0;
else depth[i]=mx[i];
}
for(i=2;i<=tot;++i) G[pre[i]].push_back(Edge(i,depth[i]-depth[pre[i]]));
}
#undef M
}lcp,lcs;
namespace vir {
vector<int>G[N<<2];
vector<int>clr;
int t,sta,tot;
int is1[N<<2],is2[N<<2];
int dfn[N<<2],dep[N<<2],size[N<<2],son[N<<2],top[N<<2],f[N<<2];
int S[N<<2],val[N<<2],re[N<<2];
ull size1[N<<2],size2[N<<2];
ull sum1[N<<2],sum2[N<<2];
bool cmp(int a,int b)
{
return dfn[a]<dfn[b];
}
void get_dfn(int x,int fa)
{
dfn[x]=++t;
size[x]=1;
f[x]=fa;
for(int i=0;i<lcp.G[x].size();++i)
{
int y=lcp.G[x][i].to;
if(y==fa) continue;
dep[y]=dep[x]+1;
get_dfn(y,x);
size[x]+=size[y];
if(size[y]>size[son[x]]) son[x]=y;
}
}
void dfs2(int u,int tp)
{
top[u]=tp;
if(son[u]) dfs2(son[u],tp);
for(int i=0;i<lcp.G[u].size();++i)
{
int v=lcp.G[u][i].to;
if(v==son[u]||v==f[u]) continue;
dfs2(v,v);
}
}
int LCA(int x,int y)
{
while(top[x]!=top[y])
{
dep[top[x]]>dep[top[y]]?x=f[top[x]]:y=f[top[y]];
}
return dep[x]<dep[y]?x:y;
}
void _new(int x,int v,int c)
{
++tot;
re[tot]=x;
val[x]=v;
if(c==1) is1[x]=1;
else is2[x]=1;
}
void addvir(int x,int y)
{
G[x].push_back(y);
}
void Initialize()
{
t=0;
get_dfn(1,0);
dfs2(1,1);
}
void Insert(int x)
{
if(sta<=1)
{
S[++sta]=x;
return;
}
int lca=LCA(S[sta],x);
if(lca==S[sta]) S[++sta]=x;
else
{
while(sta>1&&dep[S[sta-1]]>=dep[lca]) addvir(S[sta-1],S[sta]),--sta;
if(S[sta]==lca) S[++sta]=x;
else
{
addvir(lca,S[sta]);
S[sta]=lca;
S[++sta]=x;
}
}
}
void Build()
{
sta=0;
sort(re+1,re+1+tot,cmp);
if(re[1]!=1) S[++sta]=1;
for(int i=1;i<=tot;++i) Insert(re[i]);
while(sta>1) addvir(S[sta-1],S[sta]),--sta;
}
void DP(int x)
{
clr.push_back(x);
for(int i=0;i<G[x].size();++i)
{
int y=G[x][i];
DP(y);
size1[x]+=size1[y];
size2[x]+=size2[y];
sum1[x]+=sum1[y];
sum2[x]+=sum2[y];
}
ull tmp=0;
ull cntw=0;
ull cur=0;
for(int i=0;i<G[x].size();++i)
{
int y=G[x][i];
tmp+=(sum1[x]-sum1[y])*size2[y];
tmp+=(sum2[x]-sum2[y])*size1[y];
cntw+=(size1[x]-size1[y])*size2[y];
}
cur+=tmp*lcp.depth[x];
cur-=cntw*W*lcp.depth[x];
if(is1[x])
{
cur+=size2[x]*val[x]*lcp.depth[x];
cur+=sum2[x]*lcp.depth[x];
cur-=size2[x]*W*lcp.depth[x];
}
if(is2[x])
{
cur+=size1[x]*val[x]*lcp.depth[x];
cur+=sum1[x]*lcp.depth[x];
cur-=size1[x]*W*lcp.depth[x];
}
ans+=cur/2;
size1[x]+=is1[x];
size2[x]+=is2[x];
sum1[x]+=is1[x]*val[x];
sum2[x]+=is2[x]*val[x];
G[x].clear();
}
void solve()
{
Build();
DP(1);
for(int i=0;i<clr.size();++i)
{
int x=clr[i];
val[x]=sum1[x]=sum2[x]=size1[x]=size2[x]=is1[x]=is2[x]=0;
}
for(int i=1;i<=tot;++i)
{
re[i]=0;
}
tot=0;
sta=0;
clr.clear();
}
}; // 虚树
int tot,edges=1;
int totsize,rt1,rt2,mx,ed,lsc,rsc;
int hd[N<<2],vis[N<<3],size[N<<2];
struct Edge {
int to,w,nex;
}e[N<<3];
struct Node {
int u,dis,val;
Node(int u=0,int dis=0,int val=0):u(u),dis(dis),val(val){}
}L[N<<2],R[N<<2];
void add_div(int x,int y,int z)
{
e[++edges].nex=hd[x],hd[x]=edges,e[edges].to=y,e[edges].w=z;
}
void Build_Tree(int x,int fa)
{
int ff=0;
for(int i=0;i<lcs.G[x].size();++i)
{
int y=lcs.G[x][i].to;
if(y==fa) continue;
if(!ff)
{
ff=x;
add_div(ff,y,lcs.G[x][i].w);
add_div(y,ff,lcs.G[x][i].w);
}
else
{
++tot;
add_div(ff,tot,0);
add_div(tot,ff,0);
add_div(tot,y,lcs.G[x][i].w);
add_div(y,tot,lcs.G[x][i].w);
ff=tot;
}
Build_Tree(y,x);
}
}
void find_edge(int x,int fa)
{
size[x]=1;
for(int i=hd[x];i;i=e[i].nex)
{
int y=e[i].to;
if(y==fa||vis[i]) continue;
find_edge(y,x);
int now=max(size[y],totsize-size[y]);
if(now<mx)
{
mx=now;
ed=i;
rt1=y;
rt2=x;
}
size[x]+=size[y];
}
}
void get_node(int x,int fa,int dep,int ty)
{
if(ty==1)
{
if(lcs.sam_str[x]) L[++lsc]=Node(lcs.sam_str[x],lcs.depth[x]-dep,dep);
}
else
{
if(lcs.sam_str[x]) R[++rsc]=Node(lcs.sam_str[x],lcs.depth[x]-dep,dep);
}
for(int i=hd[x];i;i=e[i].nex)
{
int y=e[i].to;
if(vis[i]||y==fa) continue;
get_node(y,x,dep+e[i].w,ty);
}
}
void Divide_And_conquer(int x)
{
if(totsize==1) return;
mx=inf;
rt1=rt2=ed=0;
find_edge(x,0);
vis[ed]=vis[ed^1]=1;
lsc=rsc=0;
get_node(rt1,0,0,1);
get_node(rt2,0,0,2);
W=(ull)e[ed].w;
ull tmp=ans;
for(int i=1;i<=lsc;++i) vir::_new(lcp.str_sam[L[i].u],L[i].dis,1);
for(int i=1;i<=rsc;++i) vir::_new(lcp.str_sam[R[i].u],R[i].dis,2);
vir::solve();
int tmprt1=rt1,tmprt2=rt2;
int sizert1=size[rt1],sizert2=totsize-size[rt1];
totsize=sizert1;
Divide_And_conquer(tmprt1);
totsize=sizert2;
Divide_And_conquer(tmprt2);
}
int main()
{
// IO::setIO("input");
int i,j,n;
scanf("%s%d%d",S+1,&K1,&K2);
n=strlen(S+1);
lcp.Initialize();
lcs.Initialize();
for(i=1;i<=n;++i)
{
lcs.extend(S[i]-'a');
lcp.extend(S[n-i+1]-'a');
}
lcs.Build_LCS();
lcp.Build_LCP();
tot=lcs.tot;
Build_Tree(1,0);
vir::Initialize();
totsize=tot;
Divide_And_conquer(1);
printf("%llu\n",ans);
return 0;
}