好久没写 AC
自动机了,来一发套线段树的题解吧。
题目大意
给定包含 k k k 个字符串的集合 S S S,有 n n n 个操作,操作有三种类型:
-
以
?
开头的操作为询问操作,询问当前字符串集S中的每一个字符串匹配询问字符串的次数之和; -
以
+
开头的操作为添加操作,表示将编号为 i i i 的字符串加入到集合中; -
以
-
开头的操作为删除操作,表示将编号为 i i i 的字符串从集合中删除。
注意当编号为i的字符串已经在集合中时,允许存在添加编号为 i i i 的字符串,删除亦然。
解题思路
读题,发现题目是要求多模式串匹配数总和。
需要用到一个结论:这个模式串一定会被这个模式串末尾位置在 fail
树上的子树代表的字符串匹配到。
由于所有字符串都已经给出,那么先建好 fail
树,标记上每个串的结尾。
故添加一个字符串,只需要在这个串的结尾在 fail
树上的子树标记都加
1
1
1,删除则减
1
1
1。
标记用线段树维护即可。
对于询问,将他放在 fail
树上跑,并累加标记即可。
具体请看代码。
CODE
#include <bits/stdc++.h>
using namespace std;
inline int read()
{
int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-')
f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
{
x = x * 10 + c - '0';
c = getchar();
}
return x * f;
}
int n, m;
char str[1000007];
int cnt, id[100007], tr[1000007][26], fail[1000007];
int tot, head[1000007], to[1000007], nxt[1000007];
int cnt_node, L[1000007], R[1000007];
int del[100007];
void add(int u, int v)
{
to[++tot] = v;
nxt[tot] = head[u];
head[u] = tot;
}
void insert(int x)
{
scanf("%s", str);
int len = strlen(str);
int p = 0;
for (int i = 0; i < len; ++i)
{
int v = str[i] - 'a';
if (!tr[p][v])
tr[p][v] = ++cnt;
p = tr[p][v];
}
id[x] = p;
}
void get_fail()
{
queue<int> q;
for (int i = 0; i < 26; ++i)
if (tr[0][i])
{
fail[tr[0][i]] = 0;
q.push(tr[0][i]);
}
while (!q.empty())
{
int u = q.front();
q.pop();
for (int i = 0; i < 26; ++i)
{
if (tr[u][i])
{
fail[tr[u][i]] = tr[fail[u]][i];
q.push(tr[u][i]);
}
else
{
tr[u][i] = tr[fail[u]][i];
}
}
}
for (int i = 1; i <= cnt; ++i)
add(fail[i], i);
}
void dfs(int u, int fa)
{
L[u] = ++cnt_node;
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (v == fa)
continue;
dfs(v, u);
}
R[u] = cnt_node;
}
int t[1000007 << 2], lazy[1000007 << 2];
void push_down(int o, int l, int r)
{
if (lazy[o])
{
int mid = (l + r) >> 1;
t[o << 1] += (mid - l + 1) * lazy[o];
t[o << 1 | 1] += (r - mid) * lazy[o];
lazy[o << 1] += lazy[o];
lazy[o << 1 | 1] += lazy[o];
lazy[o] = 0;
}
}
void update(int o, int l, int r, int L, int R, int k)
{
if (L <= l && r <= R)
{
t[o] += (r - l + 1) * k;
lazy[o] += k;
return;
}
int mid = (l + r) >> 1;
push_down(o, l, r);
if (L <= mid)
update(o << 1, l, mid, L, R, k);
if (R > mid)
update(o << 1 | 1, mid + 1, r, L, R, k);
t[o] = t[o << 1] + t[o << 1 | 1];
}
int query(int o, int l, int r, int L, int R)
{
if (L <= l && r <= R)
{
return t[o];
}
int mid = (l + r) >> 1;
int res = 0;
push_down(o, l, r);
if (L <= mid)
res += query(o << 1, l, mid, L, R);
if (R > mid)
res += query(o << 1 | 1, mid + 1, r, L, R);
t[o] = t[o << 1] + t[o << 1 | 1];
return res;
}
int Query()
{
int x = 0, i = 0, ans = 0;
for (int i = 1, len = strlen(str); i < len; ++i)
x = tr[x][str[i] - 'a'], ans += query(1, 1, cnt, 1, L[id[x]]);
return ans;
}
signed main()
{
m = read(), n = read();
for (int i = 1; i <= n; ++i)
insert(i);
get_fail();
dfs(0, -1);
cnt++;
for (int i = 1; i <= n; ++i)
update(1, 1, cnt, L[id[i]], R[id[i]], 1);
while (m--)
{
char opt;
int x = 0;
cin >> opt;
if (opt == '?')
{
int ans = 0;
scanf("%s", str);
// printf("%s\n", str);
for (int i = 0, len = strlen(str); i < len; ++i)
{
x = tr[x][str[i] - 'a'];
if (x)
{
ans += query(1, 1, cnt, L[x], L[x]);
}
}
printf("%d\n", ans);
continue;
}
if (opt == '-')
{
cin >> x;
if (!del[x])
del[x] = 1, update(1, 1, cnt, L[id[x]], R[id[x]], -1);
}
else
{
cin >> x;
if (del[x])
del[x] = 0, update(1, 1, cnt, L[id[x]], R[id[x]], 1);
}
}
}