对于a串可能和b串重复的部分,我们总能找到一个位置将该串分成前半段和后半段,同时使得前半段属于a串的部分达到最长,即在该串划分的位置后面的那个字母在属于a串的部分的后继为空(也就是说不能找到更长的前半段划分方式了),所以我们只要预处理以字符x为开头的b串的个数即可。
然后在A的每一个子串枚举算后面有没有以字符x为开的后继。
新建图
// whn6325689
// Mr.Phoebe
// http://blog.csdn.net/u013007900
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>
#include <functional>
#include <numeric>
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
#define eps 1e-9
#define PI acos(-1.0)
#define INF 0x3f3f3f3f
#define LLINF 1LL<<62
#define speed std::ios::sync_with_stdio(false);
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef complex<ld> point;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef vector<int> vi;
#define CLR(x,y) memset(x,y,sizeof(x))
#define CPY(x,y) memcpy(x,y,sizeof(x))
#define clr(a,x,size) memset(a,x,sizeof(a[0])*(size))
#define cpy(a,x,size) memcpy(a,x,sizeof(a[0])*(size))
#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define lowbit(x) (x&(-x))
#define MID(x,y) (x+((y-x)>>1))
#define ls (idx<<1)
#define rs (idx<<1|1)
#define lson ls,l,mid
#define rson rs,mid+1,r
#define root 1,1,n
template<class T>
inline bool read(T &n)
{
T x = 0, tmp = 1;
char c = getchar();
while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
if(c == EOF) return false;
if(c == '-') c = getchar(), tmp = -1;
while(c >= '0' && c <= '9') x *= 10, x += (c - '0'),c = getchar();
n = x*tmp;
return true;
}
template <class T>
inline void write(T n)
{
if(n < 0)
{
putchar('-');
n = -n;
}
int len = 0,data[20];
while(n)
{
data[len++] = n%10;
n /= 10;
}
if(!len) data[len++] = 0;
while(len--) putchar(data[len]+48);
}
//-----------------------------------
const int MAXN=90010;
const int MAXC=27;
struct SAM
{
int len[MAXN<<1],next[MAXN<<1][MAXC],fa[MAXN<<1],L,last;
SAM()
{
init();
}
void init()
{
L=0;
last=newnode(0,-1);
}
int newnode(int l,int pre)
{
fa[L]=pre;
for(int i=0; i<MAXC; i++) next[L][i]=-1;
len[L]=l;
return L++;
}
void build(char *p)
{
int le=strlen(p);
for(int i=0; i<le; i++)
add(p[i]-'a',i);
}
void add(int x,int l)
{
int pre=last,now=newnode(len[pre]+1,-1);
last=now;
while(~pre && next[pre][x]==-1)
{
next[pre][x]=now;
pre=fa[pre];
}
if(pre==-1)
{
fa[now]=0;
}
else
{
int bro=next[pre][x];
if(len[bro]==len[pre]+1)
{
fa[now]=bro;
}
else
{
int fail=newnode(len[pre]+1,fa[bro]);
for(int i=0; i<MAXC; i++)next[fail][i]=next[bro][i];
fa[bro]=fail,fa[now]=fail;
while(~pre&&next[pre][x]==bro)
{
next[pre][x]=fail;
pre=fa[pre];
}
}
}
}
} A,B;
char stra[MAXN],strb[MAXN];
int lena,lenb;
vi G[MAXN<<1];
ll dp[MAXN<<1],num[MAXN<<1];
ull len[27];
bool vis[MAXN<<1];
ll rdfs(int u)
{
if(~dp[u]) return dp[u];
dp[u]=0;
for(int i=G[u].size()-1; i>=0; i--)
dp[u]+=rdfs(G[u][i]);
return dp[u];
}
ll dfs(int u)
{
if(~num[u]) return num[u];
num[u]=1;
for(int i=0; i<26; i++)
if(~B.next[u][i])
num[u]+=dfs(B.next[u][i]);
return num[u];
}
void getans(int u)
{
if(vis[u])return ;
vis[u]=true;
bool flag=false;
for(int i=0; i<26; i++)
{
if(~A.next[u][i])
getans(A.next[u][i]);
else if(dp[u])
{
len[i]+=dp[u];
if(!flag)
{
flag=true;
len[26]+=dp[u];
}
}
}
}
int main()
{
//freopen("data.txt","r",stdin);
int T;
read(T);
while(T--)
{
CLR(dp,-1);
CLR(num,-1);
CLR(len,0);
CLR(G,0);
CLR(vis,0);
scanf("%s %s",stra,strb);
lena=strlen(stra);
lenb=strlen(strb);
A.init();
B.init();
A.build(stra);
B.build(strb);
for(int i=0; i<A.L; i++)
for(int j=0; j<26; j++)
if(~A.next[i][j])
G[A.next[i][j]].pb(i);
dp[0]=1;
for(int i=0; i<A.L; i++) if(dp[i]==-1)
rdfs(i);
dfs(0);
getans(0);
ull ans=0;
for(int i=0; i<26; i++)
if(~B.next[0][i])
ans+=num[B.next[0][i]]*len[i];
ans+=len[26];
write(ans),putchar('\n');
}
return 0;
}
拓扑排列
// whn6325689
// Mr.Phoebe
// http://blog.csdn.net/u013007900
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <climits>
#include <complex>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <bitset>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <ctime>
#include <set>
#include <map>
#include <cmath>
#include <functional>
#include <numeric>
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
#define eps 1e-9
#define PI acos(-1.0)
#define INF 0x3f3f3f3f
#define LLINF 1LL<<62
#define speed std::ios::sync_with_stdio(false);
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef complex<ld> point;
typedef pair<int, int> pii;
typedef pair<pii, int> piii;
typedef vector<int> vi;
#define CLR(x,y) memset(x,y,sizeof(x))
#define CPY(x,y) memcpy(x,y,sizeof(x))
#define clr(a,x,size) memset(a,x,sizeof(a[0])*(size))
#define cpy(a,x,size) memcpy(a,x,sizeof(a[0])*(size))
#define mp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define lowbit(x) (x&(-x))
#define MID(x,y) (x+((y-x)>>1))
#define ls (idx<<1)
#define rs (idx<<1|1)
#define lson ls,l,mid
#define rson rs,mid+1,r
#define root 1,1,n
template<class T>
inline bool read(T &n)
{
T x = 0, tmp = 1;
char c = getchar();
while((c < '0' || c > '9') && c != '-' && c != EOF) c = getchar();
if(c == EOF) return false;
if(c == '-') c = getchar(), tmp = -1;
while(c >= '0' && c <= '9') x *= 10, x += (c - '0'),c = getchar();
n = x*tmp;
return true;
}
template <class T>
inline void write(T n)
{
if(n < 0)
{
putchar('-');
n = -n;
}
int len = 0,data[20];
while(n)
{
data[len++] = n%10;
n /= 10;
}
if(!len) data[len++] = 0;
while(len--) putchar(data[len]+48);
}
//-----------------------------------
const int MAXN=90010;
const int MAXC=27;
struct SAM
{
int len[MAXN<<1],next[MAXN<<1][MAXC],fa[MAXN<<1],L,last;
int num[MAXN<<1];
SAM()
{
init();
}
void init()
{
L=0;
last=newnode(0,-1);
}
int newnode(int l,int pre)
{
fa[L]=pre;
for(int i=0; i<MAXC; i++) next[L][i]=-1;
len[L]=l;num[L]=1;
return L++;
}
void build(char *p)
{
int le=strlen(p);
for(int i=0; i<le; i++)
add(p[i]-'a',i);
}
void add(int x,int l)
{
int pre=last,now=newnode(len[pre]+1,-1);
last=now;
while(~pre && next[pre][x]==-1)
{
next[pre][x]=now;
pre=fa[pre];
}
if(pre==-1)
fa[now]=0;
else
{
int bro=next[pre][x];
if(len[bro]==len[pre]+1)
fa[now]=bro;
else
{
int fail=newnode(len[pre]+1,fa[bro]);
for(int i=0; i<MAXC; i++)next[fail][i]=next[bro][i];
fa[bro]=fail,fa[now]=fail;
while(~pre&&next[pre][x]==bro)
{
next[pre][x]=fail;
pre=fa[pre];
}
}
}
}
int topxu[MAXN<<1],sum[MAXN<<1];
void topsort()
{
CLR(sum,0);
for(int i=0; i<L; i++) sum[len[i]]++;
for(int i=1; i<L; i++) sum[i]+=sum[i-1];
for(int i=0; i<L; i++) topxu[sum[len[i]]--]=i;
for(int i=L-1; i>=0; i--)
{
int u=topxu[i];
for(int j=0; j<26; j++)
if(~next[u][j])
num[u]+=num[next[u][j]];
}
}
} A,B;
char stra[MAXN],strb[MAXN];
ull len[27];
int main()
{
freopen("data.txt","r",stdin);
int T;
read(T);
while(T--)
{
CLR(len,0);
scanf("%s %s",stra,strb);
A.init();B.init();
A.build(stra);B.build(strb);
B.topsort();
ull ans=1;
for(int i=0; i<26; i++)
if(~B.next[0][i])
len[i]=B.num[B.next[0][i]];
for(int i=0; i<26; i++)
if(A.next[0][i]==-1)
ans+=len[i];
for(int i=1; i<A.L; i++)
{
ull l=A.len[i]-A.len[A.fa[i]];
for(int j=0; j<26; j++)
if(A.next[i][j]==-1)
ans+=l*len[j];
ans+=l;
}
write(ans),putchar('\n');
}
return 0;
}