题意:
给一个字符串s, 然后再给n个串t1~tn, 问s中有多少个不同的子串, 并且每个子串都不在t1~tn中。
思路:
先搞出s的sam, 然后再用n个字符串在自动机上跑, 求出每个节点的最大匹配的长度, 然后再用每个节点匹配的最大长度去更新par节点的长度。
最后对于最大匹配长度小于val的节点, ans加上val[i] - max(mx[i], val[par[i]])。
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define mnx 100020
#define LL long long
int ch[mnx << 1][26], par[mnx << 1], val[mnx << 1], mx[mnx << 1];
LL cnt[mnx << 1];
int sz, root, lst;
int creat(int _v) {
++sz;
val[sz] = _v;
par[sz] = 0;
mx[sz] = 0;
memset(ch[sz], 0, sizeof(ch[sz]));
return sz;
}
void extend(int c) {
int p = lst;
int np = creat(val[p] + 1);
while(p && ch[p][c] == 0) ch[p][c] = np, p = par[p];
if(!p) par[np] = root;
else {
int q = ch[p][c];
if(val[q] == val[p] + 1) par[np] = q;
else {
int nq = creat(val[p] + 1);
memcpy(ch[nq], ch[q], sizeof ch[q]);
par[nq] = par[q];
par[q] = nq;
par[np] = nq;
while(p && ch[p][c] == q) ch[p][c] = nq, p = par[p];
}
}
lst = np;
}
char s[mnx];
void mark() {
int t = root, len = 0;
for(int i = 0; s[i]; ++i) {
int c = s[i] - 'a';
if(ch[t][c]) {
++len;
t = ch[t][c];
mx[t] = max(mx[t], len);
continue;
}
while(t && ch[t][c] == 0)
t = par[t];
if(!t) t = root, len = 0;
else {
len = val[t] + 1;
t = ch[t][c];
}
mx[t] = max(mx[t], len);
}
}
int b[mnx << 1];
int d[mnx << 1];
int main() {
int cas, kk = 1;
scanf("%d", &cas);
while(cas--) {
int n;
scanf("%d", &n);
scanf("%s", s);
int len = strlen(s);
sz = 0;
root = lst = creat(0);
for(int i = 0; s[i]; ++i) extend(s[i] - 'a');
for(int i = 0; i < n; ++i) {
scanf("%s", s);
mark();
}
memset(d, 0, sizeof(d));
for(int i = 1; i <= sz; ++i) d[val[i]]++;
for(int i = 1; i <= len; ++i) d[i] += d[i - 1];
for(int i = 1; i <= sz; ++i)
b[d[val[i]]--] = i;
for(int i = sz; i >= 1; --i)
if(par[b[i]])
mx[par[b[i]]] = max(mx[par[b[i]]], mx[b[i]]);
LL ans = 0;
for(int i = sz; i >= 1; --i) {
if(val[i] <= mx[i]) continue;
int sub;
if(par[i]) sub = val[par[i]];
else sub = 0;
ans += val[i] - max(sub, mx[i]);
}
printf("Case %d: %I64d\n", kk++, ans);
}
return 0;
}