题目描述:
给出一个字符串
求该字符串本质不同且不互为反串的子串的数目。
abac
abac,b,a,ab,aba,bac,ac,c
ab和ba互为反串只算一个
这个题目是SAM+PAM的板子题
SAM有两种使用的方法。
一:
输入原串然后输入一个特殊字符,然后输入原串反串。
这种我们在计算时候
a
n
s
=
(
s
a
m
.
s
u
m
−
1
L
L
∗
(
l
e
n
+
1
)
∗
(
l
e
n
+
1
)
+
p
)
/
2
ans=(sam.sum-1LL*(len+1)*(len+1)+p)/2
ans=(sam.sum−1LL∗(len+1)∗(len+1)+p)/2
二:
输入原串,sam.last置为1,输入原串的反串。
a
n
s
=
(
s
a
m
.
s
u
m
+
p
)
/
2
ans=(sam.sum+p)/2
ans=(sam.sum+p)/2
PS:
sam.sum为sam中本质不同的子串的数目
(len+1)*(len+1)是特殊字符贡献的子串的数目。
p为本质不同的回文子串的数目
详见代码及注释。
#include<bits/stdc++.h>
using namespace std;
namespace SAM {
const int N_CHAR = 27;
const int maxn = 2e5 + 50;
struct Node {
int nxt[N_CHAR], fail;
int len; // Max Length of State
int pos; // Appear Position of State, Indexed From 1
int cnt; // Appear Count of State
}node[maxn * 4];
int numn, last, root;
inline int newNode(int l, int p) {
int x = ++numn;
for (int i = 0; i < N_CHAR; i++) node[x].nxt[i] = 0;
node[x].cnt = node[x].fail = 0;
node[x].len = l;
node[x].pos = p;
return x;
}
inline void init() {
root = last = newNode(numn = 0, 0);
}
inline void addChar(int c) {
int p = last, np = newNode(node[p].len + 1, node[p].len + 1);
while (p && node[p].nxt[c] == 0) node[p].nxt[c] = np, p = node[p].fail;
if (p == 0) node[np].fail = root; else {
int q = node[p].nxt[c];
if (node[p].len + 1 == node[q].len) {
node[np].fail = q;
}
else {
int nq = newNode(node[p].len + 1, node[q].pos);
for (int i = 0; i < N_CHAR; i++) node[nq].nxt[i] = node[q].nxt[i];
node[nq].fail = node[q].fail;
node[q].fail = node[np].fail = nq;
while (p && node[p].nxt[c] == q) node[p].nxt[c] = nq, p = node[p].fail;
}
}
last = np; node[np].cnt = 1;
}
}
using namespace SAM;
struct Pam
{
int nxt[maxn][30], fail[maxn], len[maxn], s[maxn], last, n, p;
inline int newnode(int l) {
memset(nxt[p], 0, sizeof(nxt[p]));
len[p] = l;
return p++;
}
void init() {
p = 0;
newnode(0), newnode(-1);
last = n = 0;
s[0] = -1;
fail[0] = 1;
}
int getfail(int x)
{
while (s[n - len[x] - 1] != s[n])
x = fail[x];
return x;
}
void add(int c)
{
s[++n] = c;
int cur = getfail(last);
if (!nxt[cur][c]) {
int now = newnode(len[cur] + 2);
fail[now] = nxt[getfail(fail[cur])][c];
nxt[cur][c] = now;
}
last = nxt[cur][c];
}
}pam;
int main()
{
char s[maxn];
scanf("%s", s);
init();
int len = strlen(s);
for (int i = 0; i < len; i++) addChar(s[i] - 'a');
addChar(26);//方法2把这句改为last=1;
for (int i = len - 1; i >= 0; i--)addChar(s[i] - 'a');
long long ans = 0;
for (int i = numn; i >= 0; i--)
{
ans = ans + node[i].len - (node[node[i].fail].len);
}
pam.init();
for (int i = 0; i < len; i++) pam.add(s[i] - 'a');
int p = pam.p - 2;
ans = ans - 1LL*(len + 1)*(len + 1) + p;//方法二改为ans=ans+p;
printf("%lld\n", ans / 2);
return 0;
}