链接:https://ac.nowcoder.com/acm/contest/1083/F
来源:牛客网
ABCBA
时间限制:C/C++ 4秒,其他语言8秒
空间限制:C/C++ 51200K,其他语言102400K
64bit IO Format: %lld
题目描述
给出一颗n个结点n-1条边的树,再给出一个长度为n的字符串s,树上的每个点都表示一个字符,点i表示的字符是s[i],其只包含大写拉丁字符。再给出q个查询,对于每个查询,会给出两个整数u,v,表示树上的两个点。对于每个查询你将从点v开始走最短路径走到点u,并按行走的顺序连接每个结点上的字符,形成一个新的字符串H,你需要计算字符串H中包含子串‘ABCBA’的个数。子串的定义就是存在任意下标a<b<c<d<e,那么”s[a]s[b]s[c]s[d]s[e]”就构成s的一个子串。如”ABC”的子串有”A”、”B”、”C”、”AB”、”AC”、”BC”、”ABC”。
输入描述:
第一行两个数n,q。1<=n<=3e4,1<=q<=1e5。 第二行一个长度为n的字符串s。所有字符都为大写拉丁字符。 接下来n-1行每行两个数u,v表示点u和点v之间有一条边 接下来q行每行两个整数u,v。1<=u,v<=n。
输出描述:
对于每个查询输出一个整数表示点v到点u的路径上”ABCBA”子串的个数,每个答案占一行,答案对10007取模。
示例1
输入
复制
8 3 ABABCBAA 1 2 2 3 3 4 4 5 5 6 6 7 7 8 3 7 2 3 1 8
输出
复制
1 0 6
说明
对于查询3 7,从结点7走到结点3,形成的字符串为ABCBA,子串ABCBA的个数为1 对于查询1 8,从结点8走到结点3,形成的字符串为AABCBABA,子串ABCBA的个数为6
思路:
参考于橘子猫大神的博客:https://blog.csdn.net/ccsu_cat/article/details/100588528
%%%
没想到还可以这么做。
思路就是,用主席树维护点u到点lca(u,v)这一段区间的各种子串出现次数的值,并且只需要维护A,AB,ABC,ABCB,ABCBA,B,BC,BCB,BCBA,C,CB,CBA,BA这几个串的数量就可以了,我们可以发现,
ABCBA = ABCBA[ls] + ABCBA[rs] + A[ls] * BCBA[rs] + AB[ls] *CBA[rs] + ABC[ls] * BA[rs] + ABCB[ls] * A[rs]
所以这个题就变成了主席树+LCA的类型题了。。。
(太强了,完全没这方面的想法。。。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 3e4 + 10, mod = 10007;
vector<int> G[maxn];
char s[maxn];
void add(int &x, int y) {
x += y;
while (x >= mod)
x -= mod;
while (x < 0)
x += mod;
}
struct node {
int cat, A, AB, ABC, ABCB, B, BC, BCB, BCBA, C, CB, CBA, BA;
node operator+(const node &t) const {
node tmp;
tmp.cat = (cat + t.cat + A * t.BCBA + AB * t.CBA + ABC * t.BA + ABCB * t.A) % mod;
tmp.A = (A + t.A) % mod;
tmp.AB = (AB + t.AB + A * t.B) % mod;
tmp.ABC = (ABC + t.ABC + A * t.BC + AB * t.C) % mod;
tmp.ABCB = (ABCB + t.ABCB + A * t.BCB + AB * t.CB + ABC * t.B) % mod;
tmp.B = (B + t.B) % mod;
tmp.BC = (BC + t.BC + B * t.C) % mod;
tmp.BCB = (BCB + t.BCB + B * t.CB + BC * t.B) % mod;
tmp.BCBA = (BCBA + t.BCBA + B * t.CBA + BC * t.BA + BCB * t.A) % mod;
tmp.C = (C + t.C) % mod;
tmp.CB = (CB + t.CB + C * t.B) % mod;
tmp.CBA = (CBA + t.CBA + C * t.BA + CB * t.A) % mod;
tmp.BA = (BA + t.BA + B * t.A) % mod;
return tmp;
}
} tree[maxn * 20];
int rt[maxn], ls[maxn * 20], rs[maxn * 20], cnt, f[maxn][20], dep[maxn], n;
#define mid (l + r) / 2
void up(int &o, int pre, int l, int r, int k, char c) {
o = ++cnt;
ls[o] = ls[pre];
rs[o] = rs[pre];
if (l == r) {
if (c == 'A')
tree[o].A = 1;
else if (c == 'B')
tree[o].B = 1;
else if (c == 'C')
tree[o].C = 1;
return;
}
if (k <= mid)
up(ls[o], ls[pre], l, mid, k, c);
else
up(rs[o], rs[pre], mid + 1, r, k, c);
tree[o] = tree[ls[o]] + tree[rs[o]];
}
void dfs(int u, int fa) {
f[u][0] = fa;
dep[u] = dep[fa] + 1;
for (int i = 1; i < 18; i++)
f[u][i] = f[f[u][i - 1]][i - 1];
up(rt[u], rt[fa], 1, n, dep[u], s[u]);
for (auto v : G[u])
if (v != fa)
dfs(v, u);
}
int LCA(int u, int v) {
if (dep[u] < dep[v])
swap(u, v);
for (int i = 17; ~i; i--)
if (dep[f[u][i]] >= dep[v])
u = f[u][i];
if (u == v)
return u;
for (int i = 17; ~i; i--)
if (f[u][i] != f[v][i])
u = f[u][i], v = f[v][i];
return f[u][0];
}
node qu(int o, int l, int r, int ql, int qr) {
if (l >= ql && r <= qr)
return tree[o];
if (qr <= mid)
return qu(ls[o], l, mid, ql, qr);
else if (ql > mid)
return qu(rs[o], mid + 1, r, ql, qr);
else
return qu(ls[o], l, mid, ql, qr) + qu(rs[o], mid + 1, r, ql, qr);
}
int main() {
int m, u, v;
scanf("%d%d", &n, &m);
scanf("%s", s + 1);
for (int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1, 0);
while (m--) {
scanf("%d%d", &u, &v);
int lca = LCA(u, v);
int ans = 0;
if (u == lca || v == lca) {
node tmp;
if (u != lca)
tmp = qu(rt[u], 1, n, dep[lca], dep[u]);
else
tmp = qu(rt[v], 1, n, dep[lca], dep[v]);
add(ans, tmp.cat);
}
else {
node t1 = qu(rt[u], 1, n, dep[lca], dep[u]);
node t2 = qu(rt[v], 1, n, dep[lca] + 1, dep[v]);
ans = (t1.cat + t2.cat + t1.A * t2.BCBA + t1.BA * t2.CBA + t1.CBA * t2.BA + t1.BCBA * t2.A) % mod;
}
printf("%d\n", ans);
}
}