http://codeforces.com/problemset/problem/235/C
陈立杰出的后缀自动机。
题目大意:给一个字符串S,再给出m个询问串T。设T的长度为len,对每个T,问T的所有循环串在S中出现的总次数。这里循环串的定义是:把一个长度为len的字符串首尾相接,然后从任意位置开始走len步所得到的串,叫做T的循环串。如abaa的循环串有abaa、baaa、aaab、aaba。(注意相同的循环串只算一次,比如aaa的循环串只有一个aaa。)
思路:对于字符串S,我们构造S的后缀自动机。对于每一个询问串T(长度为len),设T'为T去掉最后一个字符所得到的字符串,构造TT'(长度为2·len−1,它的每个长度为len的子串恰好是T的一个循环串;因此存放TT'的缓冲区要能容纳约两倍的询问长度),在S的后缀自动机上进行匹配,就可以算出TT'的每一个位置上能匹配的最大长度。当匹配长度大于等于len时,设当前所在状态为p,则沿fa链向上跳,找到满足 val[fa[q]] < len ≤ val[q] 的状态q(即所表示子串的长度区间包含len的那个状态,它恰好对应当前这个长度为len的循环串)。设状态q所表示的子串在S中出现的次数为q->num(代码中为sc),则ans += q->num。num的计算还是通过按val排序(拓扑序)后自底向上沿fa链累加求得。注意不同位置可能匹配到同一个循环串(比如T=aaa),所以还得在每一个状态里设一个标记flag,表示当前状态是否被计算过,若已计算过则跳过即可;每个询问结束后再把这些标记清除。
数组写法(代码量少,空间小,访问速度快)
//155 ms
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <iostream>
#include <string>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <bitset>
#include <stack>
using namespace std;
// Contest-template shorthand.
#define REP(i, n) for (int i = 1; i <= int(n); ++i)  // i runs over 1..n
#define MP make_pair
#define PB push_back
#define SZ(x) (int((x).size()))
#define ALL(x) (x).begin(), (x).end()
#define X first
#define Y second

// If b is strictly smaller than a, store b into a. Returns true iff a changed.
template <typename T>
inline bool sonkmin(T &a, const T &b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}

// If b is strictly larger than a, store b into a. Returns true iff a changed.
template <typename T>
inline bool sonkmax(T &a, const T &b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}

typedef long long LL;
typedef long double LD;
const int INF = 0x3f3f3f3f;  // "large" sentinel whose bytes are all 0x3f
const int N = 1000000 + 10;  // |S| <= 10^6, plus slack

// Scratch state for SAM::solve: the states counted for the current query,
// and a per-state "already counted" mark. solve() clears every mark it sets
// before returning, so flag is all-false between queries.
vector<int> ans;
bool flag[N << 1];

// Suffix automaton over 'a'..'z', stored as parallel arrays indexed by state.
// State 0 is a null sentinel (val 0, no transitions, used as the root's
// suffix link); the root is state 1.
namespace SAM {
int sz = 0, rt = 0, last = 0;  // highest used state, root, state of the whole text
// son: transitions; fa: suffix link; val: length of the longest string in the
// state; sc: after getRight, number of occurrences in the text of the state's strings.
int son[N << 1][26], fa[N << 1], val[N << 1], sc[N << 1];

// Zero every state touched by a previous build and start over with one root.
void init() {
    for (int v = 0; v <= sz; ++v) {
        memset(son[v], 0, sizeof son[v]);
        fa[v] = 0;
        val[v] = 0;
        sc[v] = 0;
    }
    sz = 1;
    rt = 1;
    last = 1;
}

// Extend the automaton by one character c (0..25).
void add(int c) {
    const int cur = ++sz;
    int p = last;
    val[cur] = val[p] + 1;
    sc[cur] = 1;  // the new prefix ends here exactly once
    last = cur;
    // Every suffix state lacking a c-edge gets one to the new state.
    while (p != 0 && son[p][c] == 0) {
        son[p][c] = cur;
        p = fa[p];
    }
    if (p == 0) {
        fa[cur] = rt;
        return;
    }
    const int q = son[p][c];
    if (val[q] == val[p] + 1) {
        fa[cur] = q;
        return;
    }
    // Split q: the clone takes q's edges and suffix link with a shorter val.
    const int clone = ++sz;
    memcpy(son[clone], son[q], sizeof son[q]);
    fa[clone] = fa[q];
    val[clone] = val[p] + 1;
    fa[q] = clone;
    fa[cur] = clone;
    while (p != 0 && son[p][c] == q) {
        son[p][c] = clone;
        p = fa[p];
    }
}

// Turn the per-prefix marks into occurrence counts: counting-sort the states
// by val, then add each state's sc into its suffix link, longest states first.
// s is unused; n is the text length (assumes n < N).
void getRight(char *s, int n) {
    (void)s;
    static int order[N << 1];
    static int bucket[N];
    for (int l = n; l >= 0; --l) bucket[l] = 0;
    for (int v = rt; v <= sz; ++v) ++bucket[val[v]];
    for (int l = 1; l <= n; ++l) bucket[l] += bucket[l - 1];
    for (int v = rt; v <= sz; ++v) order[--bucket[val[v]]] = v;
    for (int k = sz - 1; k >= 0; --k) {
        const int v = order[k];
        if (fa[v] != 0) sc[fa[v]] += sc[v];
    }
}

// Build the automaton of s[0..n) and compute the per-state occurrence counts.
void build(char *s, int n) {
    init();
    for (const char *it = s; it != s + n; ++it) add(*it - 'a');
    getRight(s, n);
}

// s[0..n) is a query T followed by T without its last character, so
// n == 2*|T| - 1. Returns the total number of occurrences in the text of the
// distinct rotations of T.
int solve(char *s, int n) {
    const int len = (n + 1) / 2;  // recovers |T|
    int p = rt;
    int match_len = 0;  // length of the longest suffix of s[0..i] found in the text
    for (int i = 0; i < n; ++i) {
        const int c = s[i] - 'a';
        while (p != 0 && son[p][c] == 0) p = fa[p];
        if (p == 0) {
            p = rt;
            match_len = 0;
        } else {
            match_len = std::min(match_len, val[p]) + 1;
            p = son[p][c];
        }
        if (match_len < len) continue;
        // Climb to the state whose length range contains len: that state
        // represents exactly the current length-len window (one rotation).
        while (val[fa[p]] >= len) {
            p = fa[p];
            match_len = std::min(match_len, val[p]);
        }
        if (!flag[p]) {  // count each distinct rotation only once
            flag[p] = true;
            ans.push_back(p);
        }
    }
    int res = 0;
    for (std::size_t k = 0; k < ans.size(); ++k) {
        flag[ans[k]] = false;
        res += sc[ans[k]];
    }
    ans.clear();
    return res;
}
}
// Input buffer. It first holds S, then each query T extended in place to
// T + T' (T' = T without its last character), which is 2*|T| - 1 characters.
// Under the Codeforces 235C limits a single query may be about 10^6 long, so
// the buffer must be about twice N; a plain str[N] would overflow.
char str[N << 1];

// Reads S, builds its suffix automaton, then for each of the m queries prints
// how many substrings of S are rotations of T.
int main() {
    if (scanf("%s", str) != 1) return 0;
    SAM::build(str, strlen(str));
    int m;
    if (scanf("%d", &m) != 1) return 0;  // avoid looping on an uninitialized m
    while (m--) {
        if (scanf("%s", str) != 1) break;
        int len = strlen(str);
        // Every length-len window of T + T' is one rotation of T.
        for (int i = 0; i + 1 < len; ++i) str[len + i] = str[i];
        printf("%d\n", SAM::solve(str, 2 * len - 1));
    }
    return 0;
}
结构体写法(结构清晰,代码量长,访问较慢)
//280 ms
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <iostream>
#include <string>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <bitset>
#include <stack>
using namespace std;
// Contest-template shorthand.
#define REP(i, n) for (int i = 1; i <= int(n); ++i)  // i runs over 1..n
#define MP make_pair
#define PB push_back
#define SZ(x) (int((x).size()))
#define ALL(x) (x).begin(), (x).end()
#define X first
#define Y second

// If b is strictly smaller than a, store b into a. Returns true iff a changed.
template <typename T>
inline bool chkmin(T &a, const T &b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}

// If b is strictly larger than a, store b into a. Returns true iff a changed.
template <typename T>
inline bool chkmax(T &a, const T &b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}

typedef long long LL;
typedef long double LD;
const int INF = 0x3f3f3f3f;  // "large" sentinel whose bytes are all 0x3f
typedef std::pair<int, int> pii;
const int N = 1000000 + 10;  // |S| <= 10^6, plus slack
// One suffix-automaton state: transitions ch, suffix link fa (a pool index),
// length of the longest string val, and occurrence count sc. Every field
// starts at zero.
struct Node {
    int ch[26], fa, val, sc;
    Node() { clear(); }
    // Return the state to its all-zero form.
    void clear() {
        memset(ch, 0, sizeof ch);
        fa = 0;
        val = 0;
        sc = 0;
    }
} pool[N << 1];  // state storage (SAM uses index 0 as a null sentinel, root at 1)

// Scratch state for SAM::solve: states counted for the current query and a
// per-state "already counted" mark.
vector<int> ans;
bool flag[N << 1];
namespace SAM {
int sz = 0, rt = 0, last = 0;
void init() {
for(int i = 0; i <= sz; i ++) pool[i].clear();
sz = 0; rt = ++ sz; last = rt;
}
void add(int c) {
int p = last, np = ++ sz;
last = np; pool[np].val = pool[p].val + 1;
pool[np].sc = 1;
for (; p && !pool[p].ch[c]; p = pool[p].fa) pool[p].ch[c] = np;
if (p) {
int q = pool[p].ch[c];
if (pool[p].val + 1 == pool[q].val) pool[np].fa = q;
else {
int nq = ++ sz;
memcpy(&pool[nq], &pool[q], sizeof(pool[q]));
pool[nq].val = pool[p].val + 1, pool[nq].sc = 0;
pool[q].fa = nq; pool[np].fa = nq;
for (; p && pool[p].ch[c] == q; p = pool[p].fa) pool[p].ch[c] = nq;
}
}
else pool[np].fa = rt;
}
void getRight(char *s, int n) {
static int Q[N << 1];
static int cnt[N];
for (int i = 0; i <= n; ++ i) cnt[i] = 0;
for (int p = rt; p <= sz; ++ p) cnt[pool[p].val] ++;
for (int i = 1; i <= n; ++ i) cnt[i] += cnt[i - 1];
for (int p = rt; p <= sz; ++ p) Q[-- cnt[pool[p].val]] = p;
for (int i = sz - 1; i >= 0; -- i) {
int p = Q[i]; if (pool[p].fa) pool[pool[p].fa].sc += pool[p].sc;
}
}
void build(char *s, int n) {
init();
for (int i = 0; i < n; ++ i) add(s[i] - 'a');
getRight(s, n);
}
int solve(char *s, int n) {
int res = 0, len = (n + 1) / 2;
int p = rt;
int match_len = 0;
for(int i = 0; i < n; i ++) {
int c = s[i] - 'a';
while(p && !pool[p].ch[c]) p = pool[p].fa;
if(p) {
match_len = min(match_len, pool[p].val) + 1;
p = pool[p].ch[c];
} else p = rt, match_len = 0;
if(match_len >= len) {
while(pool[pool[p].fa].val >= len) {
p = pool[p].fa;
match_len = min(match_len, pool[p].val);
}
if(flag[p]) continue;
flag[p] = true;
ans.push_back(p);
}
}
for(int i = 0; i < ans.size(); i ++)
flag[ans[i]] = 0, res += pool[ans[i]].sc;
ans.clear();
return res;
}
}
// Input buffer. It first holds S, then each query T extended in place to
// T + T' (T' = T without its last character), which is 2*|T| - 1 characters.
// Under the Codeforces 235C limits a single query may be about 10^6 long, so
// the buffer must be about twice N; a plain str[N] would overflow.
char str[N << 1];

// Reads S, builds its suffix automaton, then for each of the m queries prints
// how many substrings of S are rotations of T.
int main() {
    if (scanf("%s", str) != 1) return 0;
    SAM::build(str, strlen(str));
    int m;
    if (scanf("%d", &m) != 1) return 0;  // avoid looping on an uninitialized m
    while (m--) {
        if (scanf("%s", str) != 1) break;
        int len = strlen(str);
        // Every length-len window of T + T' is one rotation of T.
        for (int i = 0; i + 1 < len; ++i) str[len + i] = str[i];
        printf("%d\n", SAM::solve(str, 2 * len - 1));
    }
    return 0;
}
结构体指针写法(结构清晰,代码量小,访问慢,指针空间消耗大)
//280 ms
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <iostream>
#include <string>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <bitset>
#include <stack>
using namespace std;
// Contest-template shorthand.
#define REP(i, n) for (int i = 1; i <= int(n); ++i)  // i runs over 1..n
#define MP make_pair
#define PB push_back
#define SZ(x) (int((x).size()))
#define ALL(x) (x).begin(), (x).end()
#define X first
#define Y second

// Lower a to b when b is strictly smaller; true iff a was changed.
template <typename T>
inline bool chkmin(T &a, const T &b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}

// Raise a to b when b is strictly larger; true iff a was changed.
template <typename T>
inline bool chkmax(T &a, const T &b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}

typedef long long LL;
typedef long double LD;
const int INF = 0x3f3f3f3f;  // "large" sentinel whose bytes are all 0x3f
typedef std::pair<int, int> pii;
const int N = 1000000 + 10;  // |S| <= 10^6, plus slack
// One suffix-automaton state: transitions ch and suffix link fa (pointers into
// pool, null when absent), length of the longest string val, and occurrence
// count sc. Every field starts zeroed.
struct Node {
    Node *ch[26], *fa;
    int val, sc;
    Node() { clear(); }
    // Return the state to its all-zero form.
    void clear() {
        memset(ch, 0, sizeof ch);
        fa = 0;
        val = 0;
        sc = 0;
    }
} pool[N << 1], *rt, *last;  // state storage, root, state of the whole text

// Scratch state for SAM::solve: states counted for the current query and a
// per-state "already counted" mark indexed by position in pool.
vector<Node *> ans;
bool flag[N << 1];
namespace SAM {
Node *sz = pool;
void init() {
if (sz != pool) {
for (Node *p = pool; p < sz; ++ p) p->clear();
}
sz = pool; rt = sz ++; last = rt;
}
void add(int c) {
Node *p = last, *np = sz ++;
last = np; np->val = p->val + 1;
np->sc = 1;
for (; p && !p->ch[c]; p = p->fa) p->ch[c] = np;
if (p) {
Node *q = p->ch[c];
if (p->val + 1 == q->val) np->fa = q;
else {
Node *nq = sz ++; *nq = *q;
nq->sc = 0;
nq->val = p->val + 1;
q->fa = nq; np->fa = nq;
for (; p && p->ch[c] == q; p = p->fa) p->ch[c] = nq;
}
}
else np->fa = rt;
}
void getRight(char *s, int n) {
static Node* Q[N << 1];
static int cnt[N];
for (int i = 0; i <= n; ++ i) cnt[i] = 0;
for (Node *p = pool; p < sz; ++ p) cnt[p->val] ++;
for (int i = 1; i <= n; ++ i) cnt[i] += cnt[i - 1];
for (Node *p = pool; p < sz; ++ p) Q[-- cnt[p->val]] = p;
for (int i = (sz - pool) - 1; i >= 0; -- i) {
Node *p = Q[i]; if (p->fa) p->fa->sc += p->sc;
}
}
void build(char *s, int n) {
init();
for (int i = 0; i < n; ++ i) add(s[i] - 'a');
getRight(s, n);
}
int solve(char *s, int n) {
int res = 0, len = (n + 1) / 2;
Node *p = rt;
int match_len = 0;
for(int i = 0; i < n; i ++) {
int c = s[i] - 'a';
while(p && !p->ch[c]) p = p->fa;
if(p) {
match_len = min(match_len, p->val) + 1;
p = p->ch[c];
} else p = rt, match_len = 0;
if(match_len >= len) {
while(p != rt && p->fa->val >= len) {
p = p->fa;
match_len = min(match_len, p->val);
}
if(flag[p - pool]) continue;
flag[p - pool] = true;
ans.push_back(p);
}
}
for(int i = 0; i < ans.size(); i ++)
flag[ans[i] - pool] = 0, res += ans[i]->sc;
ans.clear();
return res;
}
}
// Input buffer. It first holds S, then each query T extended in place to
// T + T' (T' = T without its last character), which is 2*|T| - 1 characters.
// Under the Codeforces 235C limits a single query may be about 10^6 long, so
// the buffer must be about twice N; a plain str[N] would overflow.
char str[N << 1];

// Reads S, builds its suffix automaton, then for each of the m queries prints
// how many substrings of S are rotations of T.
int main() {
    if (scanf("%s", str) != 1) return 0;
    SAM::build(str, strlen(str));
    int m;
    if (scanf("%d", &m) != 1) return 0;  // avoid looping on an uninitialized m
    while (m--) {
        if (scanf("%s", str) != 1) break;
        int len = strlen(str);
        // Every length-len window of T + T' is one rotation of T.
        for (int i = 0; i + 1 < len; ++i) str[len + i] = str[i];
        printf("%d\n", SAM::solve(str, 2 * len - 1));
    }
    return 0;
}