题目链接;
https://ac.nowcoder.com/acm/contest/884/I
题意:
输入一个字符串,判断这个字符串最多有多少个不同的子串,子串不同要求:子串a 不等于 子串b ,子串a 不等于 子串b的反转串。
字符串长度小于2e5
样例:
输入:
abac
输出:
8
说明:
The set of following substrings is such a choice: abac,b,a,ab,aba,bac,ac,c.
题解:
后缀数组中的height数组之和就是字符串中有多少个重复的子串,字符串长度为n,就有n*(n+1)/2个子串,
用后缀数组求出字符串s和字符串s的反转串 总共有多少不同的子串,加上字符串s中不同回文串的个数。除以2就是答案。
求字符串中有多少不同的回文串是回文树板子题
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=4e5+5;
const int ALP=26;
int sa[maxn],Rank[maxn],rank2[maxn],height[maxn],cnt[maxn],*x,*y,mxx[maxn];
char s[maxn];
struct Palindromic_Tree {
int son[maxn][ALP]; //转移边
int fail[maxn]; //fail 指针
int cnt[maxn]; //当前节点表示的回文串在原串中出现了多少次
int num[maxn]; //当前节点 fail 可以向前跳多少次
int len[maxn]; //当前节点表示的回文串的长度
int S[maxn]; //插入的字符串
int last; //最后一次访问到的节点,类似 SAM
int n; //插入的字符串长度
int p; //自动机的总状态数
int newnode(int l) {
memset(son[p], 0, sizeof(son[p]));
cnt[p] = 0;
num[p] = 0;
len[p] = l;
return p++;
}
void init() {
p = 0;
newnode(0);
newnode(-1);
last = 0;
n = 0;
S[n] = -1;
fail[0] = 1;
}
int get_fail(int x) {
while (S[n - len[x] - 1] != S[n]) x = fail[x];
return x;
}
void add(int c) {
c -= 'a';
S[++n] = c;
int cur = get_fail(last); //通过上一次访问的位置去扩展
if (!son[cur][c]) { //如果没有对应的节点添加一个新节点
int now = newnode(len[cur] + 2);
fail[now] = son[get_fail(fail[cur])][c]; //通过当前节点的 fail 去扩展出新的 fail
son[cur][c] = now;
num[now] = num[fail[now]] + 1; //记录 fail 跳多少次
}
last = son[cur][c];
cnt[last]++; //表示当前节点访问了一次
}
void count() {
//如果某个节点出现一次,那么他的 fail 也一定会出现一次,并且在插入的时候没有计数
for (int i = p - 1; i >= 0; i--) cnt[fail[i]] += cnt[i];
}
}a;
void radix_sort(int n,int m){
memset(cnt,0,sizeof(cnt));
for(int i=0;i<n;i++) cnt[x[y[i]]]++;
for(int i=1;i<m;i++) cnt[i]+=cnt[i-1];
for(int i=n-1;i>=0;i--) sa[--cnt[x[y[i]]]]=y[i];
}
void get_sa(char s[],int n){
int m=128;
x=Rank,y=rank2;
for(int i=0;i<n;i++) x[i]=s[i],y[i]=i;
radix_sort(n,m);
for(int len=1;len<n;len<<=1){
int p=0;
for(int i=n-len;i<n;i++) y[p++]=i;
for(int i=0; i<n; i++) if(sa[i]>=len) y[p++]=sa[i]-len;
radix_sort(n,m);
swap(x,y);
x[sa[0]]=p=0;
for(int i=1;i<n;i++){
if(y[sa[i-1]]==y[sa[i]]&&sa[i-1]+len<n&&sa[i]+len<n&&y[sa[i-1]+len]==y[sa[i]+len])
x[sa[i]]=p;
else
x[sa[i]]=++p;
}
m=p+1;
if(m>=n) break;
}
for(int i=0;i<n;i++) Rank[i]=x[i];
}
void get_height(char s[],int n){
int k=0;
for(int i=0;i<n;i++){
if(Rank[i] == 0) continue;
k=max(0,k-1);
int j=sa[Rank[i]-1];
while(i+k<n&&j+k<n&&s[i+k]==s[j+k]) k++;
height[Rank[i]] = k;
}
}
int main(){
scanf("%s",s);
ll n=strlen(s);
a.init();
for(int i=0;i<n;i++) a.add(s[i]);
ll cnt=a.p-2;
s[n]='@';
for(int i=0;i<n;i++) s[i+n+1]=s[n-i-1];
ll x=n;
n=2*n+1;
s[n]=0;
get_sa(s,n);
get_height(s,n);
ll y=0;
for(int i=0;i<n;i++) y+=height[i];
ll ans=(n*(n+1)/2-y-(x+1)*(x+1)+cnt)/2;
printf("%lld\n",ans);
return 0;
}