Description
字符串树本质上还是一棵树,即N个节点N-1条边的连通无向无环图,节点从1到N编号。与普通的树不同的是,树上的每条边都对应了一个字符串。每次给出一个字符串S和两个节点U,V,需要回答U和V之间上有多少个字符串以S为前缀。
Sample Input
4
1 2 ab
2 4 ac
1 3 bc
3
1 4 a
3 4 b
3 2 ab
Sample Output
2
1
1
这道题由于它字符串长度小于10,那你就对它的所有前缀的hash暴力建树,然后像COT那样做就好了,不过我的代码好像不太优秀,OJ时间复杂度倒数Rank1。。。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef unsigned long long ULL;
int read() {
int x = 0, f = 1;char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
return x * f;
}
const int hs = 29;
const ULL inf = 18446744073709551615UL;
struct edge {
int x, y, next, o;
ULL ss[11];
} e[210000]; int len, last[110000];
struct node {
int lc, rc, c;
} t[31000000]; int cnt, rt[210000];
ULL A[11]; int fa[110000][21], dep[110000];
char ss[11];
void ins(int x, int y, int kk) {
e[++len].x = x; e[len].y = y; e[len].o = kk;
for(int i = 1; i <= kk; i++) e[len].ss[i] = A[i];
e[len].next = last[x]; last[x] = len;
}
void Link(int &u, ULL l, ULL r, ULL p) {
if(!u) u = ++cnt;
t[u].c++;
if(l == r) return ;
ULL mid = l / 2 + r / 2;
if(p <= mid) Link(t[u].lc, l, mid, p);
else Link(t[u].rc, mid + 1, r, p);
}
void Merge(int &u1, int u2) {
if(!u1 || !u2) {u1 = u1 + u2; return ;}
t[u1].c += t[u2].c;
Merge(t[u1].lc, t[u2].lc);
Merge(t[u1].rc, t[u2].rc);
}
int query(int u1, int u2, int u3, ULL l, ULL r, ULL k) {
if(l == r) return t[u1].c + t[u2].c - t[u3].c * 2;
ULL mid = l / 2 + r / 2;
if(k <= mid) return query(t[u1].lc, t[u2].lc, t[u3].lc, l, mid, k);
else return query(t[u1].rc, t[u2].rc, t[u3].rc, mid + 1, r, k);
}
void dfs(int x) {
for(int i = 1; i <= 18; i++) fa[x][i] = fa[fa[x][i - 1]][i - 1];
for(int k = last[x]; k; k = e[k].next) {
int y = e[k].y;
if(y != fa[x][0]) {
for(int i = 1; i <= e[k].o; i++) Link(rt[y], 0, inf, e[k].ss[i]);
Merge(rt[y], rt[x]);
fa[y][0] = x; dep[y] = dep[x] + 1;
dfs(y);
}
}
}
int LCA(int x, int y) {
if(dep[x] < dep[y]) swap(x, y);
for(int i = 18; i >= 0; i--) {
if(dep[x] - dep[y] >= (1 << i)) {
x = fa[x][i];
}
}
if(x == y) return x;
for(int i = 18; i >= 0; i--) {
if(fa[x][i] != fa[y][i]) {
x = fa[x][i], y = fa[y][i];
}
}
return fa[x][0];
}
int main() {
int n = read();
for(int i = 1; i < n; i++) {
int x = read(), y = read();
scanf("%s", ss + 1); int u = strlen(ss + 1);
for(int i = 1; i <= u; i++) A[i] = A[i - 1] * hs + ss[i] - 'a' + 1;
ins(x, y, u); ins(y, x, u);
}
dfs(1);
int m = read();
for(int i = 1; i <= m; i++) {
int x = read(), y = read(); scanf("%s", ss + 1);
ULL cc = 0;
for(int i = 1; i <= strlen(ss + 1); i++) cc = cc * hs + ss[i] - 'a' + 1;
int lca = LCA(x, y);
printf("%d\n", query(rt[x], rt[y], rt[lca], 0, inf, cc));
}
return 0;
}