题目地址:
https://www.luogu.com.cn/problem/P8306
题目描述:
给定
n
n
n个模式串
s
1
,
s
2
,
…
,
s
n
s_1, s_2, \dots, s_n
s1,s2,…,sn和
q
q
q次询问,每次询问给定一个文本串
t
i
t_i
ti,请回答
s
1
∼
s
n
s_1 \sim s_n
s1∼sn中有多少个字符串
s
j
s_j
sj满足
t
i
t_i
ti是
s
j
s_j
sj的前缀。一个字符串
t
t
t是
s
s
s的前缀当且仅当从
s
s
s的末尾删去若干个(可以为
0
0
0个)连续的字符后与
t
t
t相同。输入的字符串大小敏感。例如,字符串Fusu
和字符串fusu
不同。
输入格式:
输入的第一行是一个整数,表示数据组数
T
T
T。
对于每组数据,格式如下:
第一行是两个整数,分别表示模式串的个数
n
n
n和询问的个数
q
q
q。
接下来
n
n
n行,每行一个字符串,表示一个模式串。
接下来
q
q
q行,每行一个字符串,表示一次询问。
输出格式:
按照输入的顺序依次输出各测试数据的答案。
对于每次询问,输出一行一个整数表示答案。
数据范围:
对于全部的测试点,保证
1
≤
T
,
n
,
q
≤
1
0
5
1 \leq T, n, q\leq 10^5
1≤T,n,q≤105,且输入字符串的总长度不超过
3
×
1
0
6
3 \times 10^6
3×106。输入的字符串只含大小写字母和数字,且不含空串。
说明:
std的IO使用的是关闭同步后的cin/cout
,本题不卡常。
可以用Trie,每个节点还需要另外存一下经过该节点的字符串的总个数 c c c,这样查询的时候,可以顺着查询字符串向下走,如果走不动了则返回 0 0 0,否则返回最后停留的节点的 c c c值。代码如下:
#include <iostream>
using namespace std;
const int N = 3e6 + 10;
int n, q;
char s[N];
int tr[N][65], idx;
int cnt[N];
int mp['z' + 1];
void add() {
cnt[0]++;
int c = 0;
for (int i = 1; s[i]; i++) {
int pos = mp[s[i]];
if (!tr[c][pos]) tr[c][pos] = ++idx;
c = tr[c][pos];
cnt[c]++;
}
}
int query() {
int c = 0;
for (int i = 1; s[i]; i++) {
int pos = mp[s[i]];
if (!tr[c][pos]) return 0;
c = tr[c][pos];
}
return cnt[c];
}
int main() {
int cc = 0;
for (char ch = 'A'; ch <= 'Z'; ch++) mp[ch] = cc++;
for (char ch = 'a'; ch <= 'z'; ch++) mp[ch] = cc++;
for (char ch = '0'; ch <= '9'; ch++) mp[ch] = cc++;
int T;
scanf("%d", &T);
while (T--) {
for (int i = 0; i <= idx; i++)
for (int j = 0; j <= 'z'; j++)
tr[i][j] = 0;
for (int i = 0; i <= idx; i++)
cnt[i] = 0;
idx = 0;
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++) {
scanf("%s", s + 1);
add();
}
while (q--) {
scanf("%s", s + 1);
printf("%d\n", query());
}
}
}
每组数据时间复杂度 O ( ∑ i s i + ∑ i q i ) O(\sum_i s_i+\sum_iq_i) O(∑isi+∑iqi), s i s_i si是每次插入的字符串长度, q i q_i qi是每次查询的字符串长度,空间 O ( ∑ i s i ) O(\sum_i s_i) O(∑isi)。