题意:给出n个字符串,再给出m个询问,每个询问包含一个字符串,问该字符串是n个字符串中多少个字符串的子串。
分析:利用字典树能很好解决若干字符串中某个前缀出现次数的特性,将给出的n个字符串分解插入字典树中。另外,由于这里某个字符串可能有多个拥有相同前缀的子串,因此可以在字典树结点中加入一个信息num,用于表示该子串是第num个字符串的子串。
代码如下:
#include <cstdio>
#include <stack>
#include <set>
#include <iostream>
#include <string>
#include <vector>
#include <queue>
#include <list>
#include <functional>
#include <cstring>
#include <algorithm>
#include <cctype>
#include <string>
#include <map>
#include <iomanip>
#include <cmath>
#define LL long long
#define ULL unsigned long long
#define SZ(x) (int)x.size()
#define Lowbit(x) ((x) & (-x))
#define MP(a, b) make_pair(a, b)
#define MS(arr, num) memset(arr, num, sizeof(arr))
#define PB push_back
#define F first
#define S second
#define ROP freopen("input.txt", "r", stdin);
#define MID(a, b) (a + ((b - a) >> 1))
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define lrt rt << 1
#define rrt rt << 1|1
#define root 1,n,1
#define BitCount(x) __builtin_popcount(x)
#define BitCountll(x) __builtin_popcountll(x)
#define LeftPos(x) 32 - __builtin_clz(x) - 1
#define LeftPosll(x) 64 - __builtin_clzll(x) - 1
const double PI = acos(-1.0);
const LL INF = (((LL)1)<<62)+1;
using namespace std;
const double eps = 1e-5;
const int MAXN = 300 + 10;
const int MOD = 1000007;
const double M=1e-8;
const int N=30;
typedef pair<int, int> pii;
typedef pair<int, string> pis;
const int d[4][2]={{0,1},{0,-1},{-1,0},{1,0}};
int n,k,m;
struct node
{
int id,cnt;
node *next[26];
node()
{
id=cnt=0;
MS(next,NULL);
}
};
class trie
{
public:
node *rt;
trie() { rt=new node; }
void insert(char *s,int id)
{
node *p=rt;
int i=0;
while(s[i]) {
if (p->next[s[i]-'a']==NULL) p->next[s[i]-'a']=new node;
p=p->next[s[i]-'a'];
if (p->id!=id) {
p->cnt++;
}
p->id=id;
i++;
}
}
int find(char s[])
{
node *p=rt;
int i=0;
while(s[i]) {
if (p->next[s[i]-'a']==NULL) return 0;
p=p->next[s[i]-'a'];
i++;
}
return p->cnt;
}
void dfs(node *p)
{
for (int i=0;i<26;i++) if (p->next[i]) {
dfs(p->next[i]);
delete p->next[i];
}
}
void del()
{
dfs(rt);
}
};
int main()
{
int i,j;
trie t;
cin>>n;
for (i=1;i<=n;i++) {
char s[N];
scanf("%s",s);
for (j=0;s[j];j++) {
t.insert(s+j,i);
}
}
cin>>n;
for (i=0;i<n;i++) {
char s[N];
scanf("%s",s);
printf("%d\n",t.find(s));
}
t.del();
}