传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=4477
题意
给定一棵树,每条边上都有一个字符串。有q个询问(u,v,S),其中u、v是树上的两个节点,S是一个字符串,求u到v路径上的所有边中,以S为前缀的字符串个数。
题解
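把每条边上的字符串下放到该边深度较大的端点(即子节点)上,这样u到v路径上的边恰好对应路径上除LCA以外的所有点。按树链剖分的dfs序依次插入,建立可持久化trie树;对每段重链区间[l,r],用版本root[r]与root[l-1]作差即可得到以S为前缀的字符串个数。时间复杂度为O(q·|S|·log n),其中|S|≤10。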
code
#include <bits/stdc++.h>
#define N 100010  // capacity of every per-node / per-position array
using namespace std;

// str[v]: string of the edge (fa[v], v), stored from index 1 and
// NUL-terminated (the array is zero-initialised); str[root] stays empty.
char str[N][15];

int n;    // number of tree nodes
int q;    // number of queries
int tot;  // adjacency-list slots used so far (slot 0 is never used)
int QT;   // running counter for positions in the heavy-first order

int head[N];  // head[u]: first adjacency slot of u, 0 terminates the list

// Heavy-light decomposition data, indexed by node unless stated otherwise.
// NOTE(review): with `using namespace std;` an unqualified `size` is ambiguous
// with std::size in C++17; qualify uses as ::size (or rename the array).
int size[N];  // subtree size
int dfn[N];   // depth of the node (root = 1) -- despite the name, not a DFS index
int fa[N];    // parent (0 for the root)
int top[N];   // head of the heavy chain containing the node
int son[N];   // heavy child (0 for a leaf)
int w[N];     // position of the node in the heavy-first order (1-based)
int ran[N];   // ran[pos]: node placed at position pos (inverse of w)
// One directed adjacency-list slot. Every undirected tree edge is stored
// twice (once per direction), so the pool holds 2 * N slots.
struct ss
{
    int next;     // following slot in the same adjacency list (0 ends it)
    int to;       // node this slot points to
    char ch[15];  // edge string, kept from index 1; unused bytes stay zero
};
ss Edge[2 * N];
// Prepends a directed slot x -> y carrying the string ch[1..] to the
// adjacency list of x. Only strlen(ch + 1) bytes are copied; the terminator
// comes from the zero-initialised global Edge pool.
void add(int x, int y, char *ch)
{
    const int slot = ++tot;
    Edge[slot].to = y;
    Edge[slot].next = head[x];
    head[x] = slot;
    memcpy(Edge[slot].ch + 1, ch + 1, strlen(ch + 1));
}
namespace trie
{
int cnt = 1;
int sum[N * 10], son[N * 10][27], root[N];
int Insert(int x, char *ch)
{
int tmp = ++cnt, y = cnt, len = strlen(ch + 1);
for(int i = 1; i <= len; i++)
{
int p = ch[i] - 'a' + 1;
for(int j = 0; j <= 26; j++)
son[y][j] = son[x][j];
sum[y] = sum[x] + 1, son[y][p] = ++cnt;
x = son[x][p], y = son[y][p];
}
sum[y] = sum[x] + 1;
return tmp;
}
int query(int l, int r, char *ch)
{
int x = root[l - 1], y = root[r], len = strlen(ch + 1);
for(int i = 1; i <= len; i++)
{
int p = ch[i] - 'a' + 1;
x = son[x][p], y = son[y][p];
}
return sum[y] - sum[x];
}
void build()
{
for(int i = 2; i <= n; i++)
root[i] = Insert(root[i - 1], str[ran[i]]);
}
}
void dfs(int u, int rt, int deep)
{
size[u] = 1, fa[u] = rt, dfn[u] = deep;
for(int i = head[u]; i; i = Edge[i].next)
{
int to = Edge[i].to;
if(to == rt) continue;
memcpy(str[to] + 1, Edge[i].ch + 1, sizeof(char) * strlen(Edge[i].ch + 1));
dfs(to, u, deep + 1);
size[u] += size[to];
if(size[to] > size[son[u]]) son[u] = to;
}
}
// Second heavy-light pass: numbers nodes heavy-child-first and records the
// head of each node's chain.
//   u  - current node
//   tp - head of the chain u belongs to
void dfs(int u, int tp)
{
    top[u] = tp;
    w[u] = ++QT;
    ran[QT] = u;
    if(son[u] == 0) return;  // leaf: nothing below to number

    // The heavy child extends the current chain, so it gets the next position.
    dfs(son[u], tp);

    // Every light child heads a fresh chain of its own.
    for(int e = head[u]; e != 0; e = Edge[e].next)
    {
        const int v = Edge[e].to;
        if(v != fa[u] && v != son[u])
            dfs(v, v);
    }
}
// Counts the strings on the x..y path that have ch[1..] as a prefix.
// Each edge's string is stored on its deeper endpoint, so the path's edges
// are exactly the path's nodes except the LCA. The walk repeatedly climbs
// from the endpoint whose chain head is deeper, adding that whole chain
// segment. Once both ends share a chain, it adds the part strictly below the
// shallower end, which is the LCA.
int ask(int x, int y, char *ch)
{
    int total = 0;
    for(; top[x] != top[y]; x = fa[top[x]])
    {
        if(dfn[top[x]] < dfn[top[y]]) swap(x, y);  // dfn holds depth
        total += trie::query(w[top[x]], w[x], ch);
    }
    const int upper = dfn[x] <= dfn[y] ? x : y;
    const int lower = dfn[x] <= dfn[y] ? y : x;
    return total + trie::query(w[upper] + 1, w[lower], ch);
}
// Reads the tree (n, then n - 1 lines "x y s"), builds the heavy-light
// decomposition and the persistent trie, then answers q queries "u v S",
// printing one count per line.
int main()
{
    // freopen("tt.in", "r", stdin);
    cin >> n;
    for(int x, y, i = 1; i < n; i++)
    {
        char ch[15];
        memset(ch, 0, sizeof(ch));
        // %13s: ch[0] is skipped and one byte is needed for the terminator,
        // so at most 13 characters fit; a bare %s could overflow ch.
        scanf("%d%d%13s", &x, &y, ch + 1);
        add(x, y, ch), add(y, x, ch);
    }
    dfs(1, 0, 1), dfs(1, 1);
    trie::build();
    cin >> q;
    for(int x, y, i = 1; i <= q; i++)
    {
        char ch[15];
        memset(ch, 0, sizeof(ch));
        scanf("%d%d%13s", &x, &y, ch + 1);  // bounded for the same reason
        printf("%d\n", ask(x, y, ch));
    }
    return 0;
}