Description
给定一个n个点的树,树上每个点有一个字符,再给一个长度为m的串。
两点的价值为:两点连接形成的字符串再m串中出现的次数。
询问两两点价值的和。
Sample Input
3 5
1 2
1 3
aab
abaab
Sample Output
15
首先考虑点分治。
然后再考虑根号分治
对于一个点分治块,假如它的大小小于等于
n
\sqrt n
n,那么直接对于每一个点暴力去做它可以形成的所有串,可以用自动机在dfs中
O
(
1
)
O(1)
O(1)维护一个串出现的次数。
这一部分时间复杂度不超过
O
(
n
n
)
O(n\sqrt n)
O(nn)。
对于大于等于
n
\sqrt n
n的块,我们只用考虑经过分治中心的点即可。
对于一条路径(x,y),设分治中心为g,则我们考虑对于一个m串中的位置p的贡献。我们算出(x,g)经过了p的次数,(g,y)经过了p的次数,相乘即可。
先考虑(x,g),我们可以建出一个parent树,从g往下走,就相当于每次都在前面插个字符,那就相当于直接在parent树上走,边走边打标记,最后再跑一边parent树累加标记即可。
对于(g,y)就相当于在后缀树上做。
因为所有这样的块不超过
n
\sqrt n
n个,所以这样的时间复杂度为
O
(
n
n
)
O(n\sqrt n)
O(nn)。
#include <cmath>
#include <ctime>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010;
int _max(int x, int y) {return x > y ? x : y;}
int _min(int x, int y) {return x < y ? x : y;}
int read() {
int s = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
return s * f;
}
struct edge {
int x, y, next;
} e[2 * N]; int len, n, m, last[N];
char yy[N], tt[N];
bool v[N];
struct SAM {
int cnt, lst, fa[N * 2], len[N * 2], right[N * 2], ch[N * 2][26], son[N * 2][26];
int s[N], size[N], id[N * 2], ss[N], tag[N];
void ins(int c, int i) {
int np = ++cnt, p = lst; ss[i] = c;
len[np] = len[p] + 1; right[np] = i; size[np] = 1;
while(p && !ch[p][c]) ch[p][c] = np, p = fa[p];
if(!p) fa[np] = 1;
else {
int q = ch[p][c];
if(len[p] + 1 == len[q]) fa[np] = q;
else {
int nq = ++cnt;
fa[nq] = fa[q]; len[nq] = len[p] + 1;
memcpy(ch[nq], ch[q], sizeof(ch[nq]));
fa[np] = fa[q] = nq;
while(p && ch[p][c] == q) ch[p][c] = nq, p = fa[p];
}
} lst = np;
}
void bt() {
for(int i = 1; i <= cnt; i++) ++s[len[i]];
for(int i = 1; i <= m; i++) s[i] += s[i - 1];
for(int i = cnt; i >= 1; i--) id[s[len[i]]--] = i;
for(int i = cnt; i >= 1; i--) {
int x = id[i], F = fa[x];
size[F] += size[x];
right[F] = _max(right[F], right[x]);
son[F][ss[right[x] - len[F]]] = x;
}
}
void work(int x, int fa, int now, int ll) {
if(ll == len[now]) now = son[now][tt[x]];
else if(ss[right[now] - ll] != tt[x]) now = 0;
if(!now) return ; ++tag[now];
for(int k = last[x]; k; k = e[k].next) {
int y = e[k].y;
if(y != fa && !v[y]) work(y, x, now, ll + 1);
}
}
void push() {
for(int i = 1; i <= cnt; i++) tag[id[i]] += tag[fa[id[i]]];
}
} sam1, sam2;
int sum, block, tot[N], f[N];
int tp, now, sta[N];
int id1[N], id2[N];
LL ans;
void ins(int x, int y) {
e[++len].x = x, e[len].y = y;
e[len].next = last[x], last[x] = len;
}
int getrt(int x, int fa) {
int p = 0;
tot[x] = 1; f[x] = 0;
for(int k = last[x]; k; k = e[k].next) {
int y = e[k].y;
if(y != fa && !v[y]) {
int hh = getrt(y, x);
if(f[hh] < f[p]) p = hh;
tot[x] += tot[y];
f[x] = _max(f[x], tot[y]);
}
} f[x] = _max(f[x], sum - tot[x]);
if(f[x] < f[p]) p = x;
return p;
}
void get(int x, int fa) {
sta[++tp] = x;
for(int k = last[x]; k; k = e[k].next) {
int y = e[k].y;
if(y != fa && !v[y]) get(y, x);
}
}
void getsum(int x, int fa) {
if(!now) return ;
int hh = now; ans += sam1.size[now];
for(int k = last[x]; k; k = e[k].next) {
int y = e[k].y;
if(y != fa && !v[y]) {
now = sam1.ch[now][tt[y]], getsum(y, x);
now = hh;
}
}
}
void calc(int x, int fa, int o) {
for(int i = 1; i <= sam1.cnt; i++) sam1.tag[i] = 0;
for(int i = 1; i <= sam2.cnt; i++) sam2.tag[i] = 0;
if(o == 1) sam1.work(x, 0, 1, 0), sam2.work(x, 0, 1, 0);
else sam1.work(x, 0, sam1.son[1][tt[fa]], 1), sam2.work(x, 0, sam2.son[1][tt[fa]], 1);
sam1.push(), sam2.push();
for(int i = 1; i <= m; i++) ans += (LL)o * sam1.tag[id1[i]] * sam2.tag[id2[m - i + 1]];
}
void solve(int x) {
if(sum <= block) {
tp = 0; get(x, 0);
for(int i = 1; i <= tp; i++) now = sam1.ch[1][tt[sta[i]]], getsum(sta[i], 0);
return ;
} v[x] = 1;
calc(x, 0, 1);
for(int k = last[x]; k; k = e[k].next) {
int y = e[k].y;
if(!v[y]) {
calc(y, x, -1);
sum = tot[y], solve(getrt(y, 0));
}
}
}
int main() {
n = read(), m = read();
for(int i = 1; i < n; i++) {
int x = read(), y = read();
ins(x, y), ins(y, x);
} scanf("%s", tt + 1);
scanf("%s", yy + 1);
sam1.cnt = sam2.cnt = sam1.lst = sam2.lst = 1;
for(int i = 1; i <= m; i++) yy[i] -= 'a', sam1.ss[i] == yy[i], sam1.ins(yy[i], i), id1[i] = sam1.lst;
reverse(yy + 1, yy + m + 1);
for(int i = 1; i <= m; i++) sam2.ss[i] == yy[i], sam2.ins(yy[i], i), id2[i] = sam2.lst;
sam1.bt(), sam2.bt();
f[0] = 999999999; block = sqrt(n);
for(int i = 1; i <= n; i++) tt[i] -= 'a';
sum = n, solve(getrt(1, 0));
printf("%lld\n", ans);
return 0;
}