题目描述
给定一颗 nnn 个节点的无根树,每条边上附有一个小写英文字母。
于是一条路径对应一个字符串。
一共有 qqq 次询问,每次询问以节点 uuu 为起点的非空字符串中有多少字典序严格小于字符串 u⇝vu \leadsto vu⇝v 。
输入格式
第一行,两个个整数 n,qn, qn,q。
接下来 n−1n - 1n−1 行,每行两个整数,一个小写字母。 u,v,cu, v, cu,v,c。 表示存在字母为 ccc 的树边 (u,v)(u, v)(u,v)。保证 u≠vu \neq vu≠v。
接下来 qqq 行,每行两个整数 u,vu, vu,v。
输出格式
qqq 行,每行一个答案。
样例
样例输入 1
4 3
4 1 t
3 2 p
1 2 s
3 2
1 3
2 1
样例输出 1
0
1
1
样例解释 1
第一个询问,以 333 为起点的字符串有p,ps,pst。3⇝23\leadsto 23⇝2 生成p。没有比p字典序严格小的字符串。
第二个询问,以 111 为起点的字符串有s,sp,t。1⇝31\leadsto 31⇝3 生成sp。s字典序比sp小。
第二个询问,以 222 为起点的字符串有p,s,st。2⇝12\leadsto 12⇝1 生成s。p字典序比s小。
样例输入 2
8 4
4 6 p
3 7 o
7 8 p
4 5 d
1 3 o
4 3 p
3 2 e
8 6
3 7
8 1
4 3
样例输出 2
6
1
3
1
数据范围与提示
n≤4000,q≤50000n\le 4000, q \le 50000n≤4000,q≤50000
这道题和普通字典树不一样 字母存在边上而不是点上。。。不过其实没啥差别。
这道题难点在于如何统计严格小于 uv所得字符串的 的字符串数量
其实可以发现一个串比别的严格大的有三种情况:
1.和它长度相等的但是最后字母比它大。
2.比它长度小 但是有个字母比它大
3.比它长度大,但是前面有个字母比它大
所以如何统计呢
首先在每一个节点 建立一棵trie树
然后用sum数组一个个往前统计,因为trie树的性质,所以可以相当于把前面得所有字母都固定住,就是前面每个字母都和该比较串的一模一样。然后不断统计,从叶子节点把数量加上父节点,一层层加上去,最后加到根上,因为如果一层的a 字母比同一层的b字母大,那么就可以加上所有比<=b字母的串的数量,贴代码。然后最后还有统计上前面和它相等串的数量,直接统计tire树的点的cnt即可,因为“abcd”>”abc”,而sum 只能统计出 “abc”>”abb” 和”abc”>”abbc”这种类型的情况 加起来就是所有完整的情况。另外根节点从0开始,根节点是空串,所以加了1,输出答案时减1即可
其实就相当于 sum是从后处理到前面比它小的,而get_answer 是把前面比它小统计到真正的串的后面。因为sum就相当于把前面固定住了,实际上前面也有比它小的情况。。。
#include <bits/stdc++.h>
using namespace std;
struct node
{
int sum[27],son[27],tot,cnt;
}tr[4010];
int n,q;
vector<int> v[4010];
vector<char> ch[4010];
int ans[4010][4010];
int m;
void build(int x,int y,int p)
{
tr[p].cnt++;
for(int i=0;i<v[x].size();i++)
{
if(v[x][i]==y) continue;
int c=ch[x][i]-'a';
if(tr[p].son[c]==0)
{
m++;
tr[p].son[c]=m;
}
build(v[x][i],x,tr[p].son[c]);
}
}
void get_answer(int x,int y,int p,int cnt,int ans[])
{
ans[x]=cnt;
cnt+=tr[p].cnt;
for(int i=0;i<v[x].size();i++)
{
if(v[x][i]==y) continue;
int c=ch[x][i]-'a';
get_answer(v[x][i],x,tr[p].son[c],cnt+tr[p].sum[c],ans);
}
}
int main()
{
scanf("%d%d",&n,&q);
for(int i=1;i<n;i++)
{
int a,b;
char tmp[10];
scanf("%d%d%s",&a,&b,tmp);
v[a].push_back(b);
ch[a].push_back(tmp[0]);
v[b].push_back(a);
ch[b].push_back(tmp[0]);
}
for(int i=1;i<=n;i++)
{
m=0;
memset(tr,0,sizeof(tr));
build(i,0,0);
for(int j=m;j>=0;j--)
{
tr[j].tot=tr[j].cnt;
for(int k=0;k<26;k++)
{
tr[j].sum[k+1]=tr[j].sum[k];
if(tr[j].son[k])
{
tr[j].tot+=tr[tr[j].son[k]].tot;
tr[j].sum[k+1]+=tr[tr[j].son[k]].tot;
}
}
}
get_answer(i,0,0,0,ans[i]);
}
while(q--){
int u,vv;
scanf("%d%d",&u,&vv);
printf("%d\n",ans[u][vv]-1 );
}
}