这道题是codechef上九月月赛的一道压轴题目,题意很简单,就是时时刻刻求一个字符串的所有不同字串个数,支持动态末尾添加字符与前端删除字符功能。
题目链接: http://www.codechef.com/problems/TMP01/
无论是后缀数组还是后缀树或者是后缀自动机,他们可以说都是处理字符串题目的一个非常强大的工具,尤其是在字符串匹配方面可以说已经达到了最高的效率O(n),三个数据结构可以说在这方面都已经做到了尽善尽美。但是由于它们各自存储方式的不同也有各自的优缺点。后缀数组把信息都存储在几个数组当中如sa,rank,h数组等等,由于本人没有深入理解,所以很难有自己的见地。后缀自动机与后缀树我觉得应该比后缀数组更有用,更方便,毕竟他们是二维的东西,以二维的存储结构俯瞰一维的存储结构可以把一些复杂的东西表现的更加自然更加美观。
无论是后缀自动机还是后缀树都可以说是优化的存储了所有子串的trie树,但是由于其O(n * n)级别的时间于空间复杂度,让人难以接受,因此很多前辈高人们就研究出来了后缀树与后缀自动机。后缀自动机一条边代表一个字符,而一个节点却代表了很多的子串(不同于trie树,一个节点代表一个子串),即合并了trie树的具有后缀包含关系的节点因此后缀自动机被证明为O(n)的复杂度。而后缀树是合并了trie树只含有一个孩子的一连串节点,同样也降到了O(n)的复杂度。但是就其变成复杂度而言,无疑后缀自动机更为简单,易懂,这也是我为什么喜欢后缀自动机的原因。然而,无论任何算法都不能完全替代另一种算法,就好比求最短路的spfa算法和dijikstra算法和求最小生成树的prim和kruskal算法一样,每个算法都有其自己的存储方式与思想,对于一些扩展的问题不同方法的实现复杂度可能会有天壤之别,也有可能某些算法根本无法实现。
而这道题,我想出题者就是本着后缀树的思想出的题,因此很难被后缀树组和后缀自动机实现。网上有人说可以离线搞掉,但是时间是O(nlogn)。而用后缀树不但是在线算法, 而且时间复杂度仍是O(n).
算法描述参见: http://discuss.codechef.com/problems/TMP01。
自我感觉实现的方式还是比较neat and beautiful 的!
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
#include <map>
using namespace std;
const int Maxn = 1e6 + 10;
const int Mod =1000000007;
typedef long long ll;
struct Node
{
map<char, int> mp;
int fail, s, t, pa;
int next(char c)
{
if (s == -2) return 1;
map<char, int> :: iterator it = mp.find(c);
if (it == mp.end()) return 0;
return it -> second;
}
void clear(int _s, int _t, int _pa)
{
mp.clear();
s = _s, t = _t, fail = 0, pa = _pa;
}
int len() {return t - s;}
}node[Maxn << 1];
int size, end, ini;
void init()
{
node[0].clear(-2, -1, -1);
node[1].clear(-1, 0, 0);
size = 2, end = 1, ini = 0;
}
queue<int> leaves;
void go(char str[], int t)
{
while (ini <= t)
{
int p = node[end].next(str[ini]);
if (node[p].len() > t - ini) break;
ini += node[p].len();
end = p;
}
}
void add(char str[], int i)
{
char c = str[i];
int last = 1, cur;
for (; ini <= i; end = node[end].fail, go(str, i))
{
cur = end;
if (ini < i)
{
char c1 = str[ini];
int p = node[end].next(c1);
int k = node[p].s + i - ini;
if (str[k] == c) break;
cur = size++;
node[cur].clear(node[p].s, k, end);
node[cur].mp[str[k]] = p;
node[end].mp[c1] = cur;
node[p].pa = cur;
node[p].s = k;
}
else if (node[end].next(c)) break;
int p = size++;
node[p].clear(i, Maxn, cur);
node[cur].mp[c] = p;
leaves.push(p);
if (last != 1) node[last].fail = cur;
last = cur;
}
if (last != 1) node[last].fail = cur;
go(str, i + 1);
}
void delNode(char str[], int cur)
{
char c = node[cur].mp.begin()->first;
int ch = node[cur].next(c);
int pa = node[cur].pa;
node[ch].s -= node[cur].len();
node[ch].pa = pa;
node[pa].mp[str[node[cur].s]] = ch;
if (end == cur)
end = pa, ini -= node[cur].len();
}
int del(char str[], int i)
{
int lea = leaves.front();
int ret = i + 1 - node[lea].s;
leaves.pop();
int pa = node[lea].pa;
int active = end;
if (ini <= i) active = node[end].next(str[ini]);
if (lea != active)
{
node[pa].mp.erase(str[node[lea].s]);
if (node[pa].mp.size() == 1 && pa != 1) delNode(str, pa);
}
else
{
node[lea].s = ini;
end = node[end].fail;
go(str, i + 1);
leaves.push(lea);
ret -= i + 1 - node[lea].s;
}
return ret;
}
char str[Maxn];
int n;
int main()
{
ll ans = 0, sum = 0;
scanf("%d", &n);
init();
int num = 0;
for (int i = 0; i < n; ++i)
{
char op[4];
scanf("%s", op);
if (op[0] == '+')
{
scanf("%s", str + num);
add(str, num++);
ans += leaves.size();
}
else
{
ans -= del(str, num - 1);
}
sum += ans;
sum %= Mod;
}
printf("%lld\n", sum);
}