Address
Solution
先树剖,把树转成序列。然后线段树上每个点维护
8
8
8 个值:
z
00
z_{00}
z00 :从对应区间左端点的
A
A
A 区域走到右端点的
A
A
A 区域经过的最多格子数,如果不连通则为
0
0
0 。
z
01
z_{01}
z01 :从对应区间左端点的
A
A
A 区域走到右端点的
B
B
B 区域经过的最多格子数,如果不连通则为
0
0
0 。
z
10
z_{10}
z10 :从对应区间左端点的
B
B
B 区域走到右端点的
A
A
A 区域经过的最多格子数,如果不连通则为
0
0
0 。
z
11
z_{11}
z11 :从对应区间左端点的
B
B
B 区域走到右端点的
B
B
B 区域经过的最多格子数,如果不连通则为
0
0
0 。
f
r
00
fr_{00}
fr00 :从对应区间左端点的
A
A
A 区域开始走经过的最多格子数。
f
r
01
fr_{01}
fr01 :从对应区间左端点的
B
B
B 区域开始走经过的最多格子数。
f
r
10
fr_{10}
fr10 :从对应区间右端点的
A
A
A 区域开始走经过的最多格子数。
f
r
11
fr_{11}
fr11 :从对应区间右端点的
B
B
B 区域开始走经过的最多格子数。
下面定义运算:
x
⨁
y
=
{
x
+
y
x
,
y
≠
0
0
ELSE
x\bigoplus y=\begin{cases}x+y&x,y\ne0\\0&\text{ELSE}\end{cases}
x⨁y={x+y0x,y̸=0ELSE
线段树左右子节点合并时(
l
c
lc
lc 和
r
c
rc
rc 分别为
u
u
u 的左右子节点):
z
00
[
u
]
=
max
(
z
00
[
l
c
]
⨁
z
00
[
r
c
]
,
z
01
[
l
c
]
⨁
z
10
[
r
c
]
)
z_{00}[u]=\max(z_{00}[lc]\bigoplus z_{00}[rc],z_{01}[lc]\bigoplus z_{10}[rc])
z00[u]=max(z00[lc]⨁z00[rc],z01[lc]⨁z10[rc])
z
01
[
u
]
=
max
(
z
00
[
l
c
]
⨁
z
01
[
r
c
]
,
z
01
[
l
c
]
⨁
z
11
[
r
c
]
)
z_{01}[u]=\max(z_{00}[lc]\bigoplus z_{01}[rc],z_{01}[lc]\bigoplus z_{11}[rc])
z01[u]=max(z00[lc]⨁z01[rc],z01[lc]⨁z11[rc])
z
10
[
u
]
=
max
(
z
10
[
l
c
]
⨁
z
00
[
r
c
]
,
z
11
[
l
c
]
⨁
z
10
[
r
c
]
)
z_{10}[u]=\max(z_{10}[lc]\bigoplus z_{00}[rc],z_{11}[lc]\bigoplus z_{10}[rc])
z10[u]=max(z10[lc]⨁z00[rc],z11[lc]⨁z10[rc])
z
11
[
u
]
=
max
(
z
10
[
l
c
]
⨁
z
01
[
r
c
]
,
z
11
[
l
c
]
⨁
z
11
[
r
c
]
)
z_{11}[u]=\max(z_{10}[lc]\bigoplus z_{01}[rc],z_{11}[lc]\bigoplus z_{11}[rc])
z11[u]=max(z10[lc]⨁z01[rc],z11[lc]⨁z11[rc])
f
r
00
[
u
]
=
max
(
f
r
00
[
l
c
]
,
max
(
z
00
[
l
c
]
⨁
f
r
00
[
r
c
]
,
z
01
[
l
c
]
⨁
f
r
01
[
r
c
]
)
)
fr_{00}[u]=\max(fr_{00}[lc],\max(z_{00}[lc]\bigoplus fr_{00}[rc],z_{01}[lc]\bigoplus fr_{01}[rc]))
fr00[u]=max(fr00[lc],max(z00[lc]⨁fr00[rc],z01[lc]⨁fr01[rc]))
f
r
01
[
u
]
=
max
(
f
r
01
[
l
c
]
,
max
(
z
10
[
l
c
]
⨁
f
r
00
[
r
c
]
,
z
11
[
l
c
]
⨁
f
r
01
[
r
c
]
)
)
fr_{01}[u]=\max(fr_{01}[lc],\max(z_{10}[lc]\bigoplus fr_{00}[rc],z_{11}[lc]\bigoplus fr_{01}[rc]))
fr01[u]=max(fr01[lc],max(z10[lc]⨁fr00[rc],z11[lc]⨁fr01[rc]))
f
r
10
[
u
]
=
max
(
f
r
10
[
r
c
]
,
max
(
z
00
[
r
c
]
⨁
f
r
10
[
l
c
]
,
z
10
[
r
c
]
⨁
f
r
11
[
l
c
]
)
)
fr_{10}[u]=\max(fr_{10}[rc],\max(z_{00}[rc]\bigoplus fr_{10}[lc],z_{10}[rc]\bigoplus fr_{11}[lc]))
fr10[u]=max(fr10[rc],max(z00[rc]⨁fr10[lc],z10[rc]⨁fr11[lc]))
f
r
11
[
u
]
=
max
(
f
r
11
[
r
c
]
,
max
(
z
01
[
r
c
]
⨁
f
r
10
[
l
c
]
,
z
11
[
r
c
]
⨁
f
r
11
[
l
c
]
)
)
fr_{11}[u]=\max(fr_{11}[rc],\max(z_{01}[rc]\bigoplus fr_{10}[lc],z_{11}[rc]\bigoplus fr_{11}[lc]))
fr11[u]=max(fr11[rc],max(z01[rc]⨁fr10[lc],z11[rc]⨁fr11[lc]))
合并两条链也是和上面类似的操作。
查询路径时,需要把这条路径剖成的重链分成两组,一组是向上走,另一组是向下走,分别查询之后合并起来。这时候就可能需要查询从区间右端点走到左端点而不是左端点走到右端点。
一个区间的方向取反之后,发生的变化为:
(1)
z
01
z_{01}
z01 和
z
10
z_{10}
z10 互换。
(2)
f
r
00
fr_{00}
fr00 和
f
r
10
fr_{10}
fr10 互换。
(3)
f
r
01
fr_{01}
fr01 和
f
r
11
fr_{11}
fr11 互换。
Code
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Rof(i, a, b) for (i = a; i >= b; i--)
#define Tree(u) for (int e = adj[u], v; e; e = nxt[e]) if ((v = go[e]) != fu)
#define p2 p << 1
#define p3 p << 1 | 1
using namespace std;
inline int read()
{
int res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
return bo ? ~res + 1 : res;
}
inline char get()
{
char c;
while ((c = getchar()) != 'C' && c != 'Q');
return c;
}
const int N = 5e4 + 5, M = N << 1, Z = 5, E = 105, L = M << 1;
int n, m, ecnt, nxt[M], adj[N], go[M], fa[N], dep[N], sze[N], son[N],
top[N], pos[N], idx[N], ToT, tot1, l1[E], r1[E], tot2, l2[E], r2[E];
char s[N][Z], t[Z];
void add_edge(int u, int v)
{
nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}
template <class T>
T Max(T a, T b) {return a > b ? a : b;}
template <class T>
T Add(T a, T b) {return a && b ? a + b : 0;}
struct node
{
int z00, z01, z10, z11, fr00, fr01, fr10, fr11;
friend inline node operator ! (node x)
{
return (node) {x.z00, x.z10, x.z01, x.z11,
x.fr10, x.fr11, x.fr00, x.fr01};
}
friend inline node operator + (node x, node y)
{
if (x.z00 == -1) return y;
node res = (node) {0, 0, 0, 0, 0, 0, 0, 0};
res.z00 = Max(Add(x.z00, y.z00), Add(x.z01, y.z10));
res.z01 = Max(Add(x.z00, y.z01), Add(x.z01, y.z11));
res.z10 = Max(Add(x.z10, y.z00), Add(x.z11, y.z10));
res.z11 = Max(Add(x.z10, y.z01), Add(x.z11, y.z11));
res.fr00 = Max(x.fr00, Max(Add(x.z00, y.fr00), Add(x.z01, y.fr01)));
res.fr01 = Max(x.fr01, Max(Add(x.z10, y.fr00), Add(x.z11, y.fr01)));
res.fr10 = Max(y.fr10, Max(Add(y.z00, x.fr10), Add(y.z10, x.fr11)));
res.fr11 = Max(y.fr11, Max(Add(y.z01, x.fr10), Add(y.z11, x.fr11)));
return res;
}
} T[L], tmp1[E], tmp2[E];
node orz(char x, char y)
{
return (node) {x == '.', x == '.' && y == '.' ? 2 : 0,
x == '.' && y == '.' ? 2 : 0, y == '.',
x == '.' ? (y == '.' ? 2 : 1) : 0,
y == '.' ? (x == '.' ? 2 : 1) : 0,
x == '.' ? (y == '.' ? 2 : 1) : 0,
y == '.' ? (x == '.' ? 2 : 1) : 0};
}
void dfs1(int u, int fu)
{
dep[u] = dep[fa[u] = fu] + (sze[u] = 1);
Tree(u)
{
dfs1(v, u);
sze[u] += sze[v];
if (sze[v] > sze[son[u]]) son[u] = v;
}
}
void dfs2(int u, int fu)
{
if (son[u])
{
top[son[u]] = top[u];
idx[pos[son[u]] = ++ToT] = son[u];
dfs2(son[u], u);
}
Tree(u)
{
if (v == son[u]) continue;
idx[pos[top[v] = v] = ++ToT] = v;
dfs2(v, u);
}
}
void build(int l, int r, int p)
{
if (l == r) return (void) (T[p] = orz(s[idx[l]][1], s[idx[l]][2]));
int mid = l + r >> 1;
build(l, mid, p2); build(mid + 1, r, p3);
T[p] = T[p2] + T[p3];
}
void initalize()
{
dfs1(1, 0);
top[1] = pos[1] = idx[1] = ToT = 1;
dfs2(1, 0); build(1, n, 1);
}
void change(int l, int r, int pos, char x, char y, int p)
{
if (l == r) return (void) (T[p] = orz(x, y));
int mid = l + r >> 1;
if (pos <= mid) change(l, mid, pos, x, y, p2);
else change(mid + 1, r, pos, x, y, p3);
T[p] = T[p2] + T[p3];
}
node ask(int l, int r, int s, int e, int p)
{
if (l == s && r == e) return T[p];
int mid = l + r >> 1;
if (e <= mid) return ask(l, mid, s, e, p2);
else if (s >= mid + 1) return ask(mid + 1, r, s, e, p3);
else return ask(l, mid, s, mid, p2)
+ ask(mid + 1, r, mid + 1, e, p3);
}
int query(int u, int v)
{
int i;
tot1 = tot2 = 0;
while (top[u] != top[v])
if (dep[top[u]] > dep[top[v]])
{
l1[++tot1] = pos[top[u]]; r1[tot1] = pos[u];
u = fa[top[u]];
}
else
{
l2[++tot2] = pos[top[v]]; r2[tot2] = pos[v];
v = fa[top[v]];
}
if (dep[u] < dep[v]) l2[++tot2] = pos[u], r2[tot2] = pos[v];
else l1[++tot1] = pos[v], r1[tot1] = pos[u];
For (i, 1, tot1) tmp1[i] = !ask(1, n, l1[i], r1[i], 1);
For (i, 1, tot2) tmp2[i] = ask(1, n, l2[i], r2[i], 1);
node res = (node) {-1, 0, 0, 0, 0, 0, 0, 0};
For (i, 1, tot1) res = res + tmp1[i];
Rof (i, tot2, 1) res = res + tmp2[i];
return Max(res.fr00, res.fr01);
}
int main()
{
int i, x, y;
char op;
n = read(); m = read();
For (i, 1, n - 1) x = read(), y = read(),
add_edge(x, y);
For (i, 1, n) scanf("%s", s[i] + 1);
initalize();
while (m--)
{
op = get(); x = read();
if (op == 'C') scanf("%s", t + 1),
change(1, n, pos[x], t[1], t[2], 1);
else y = read(), printf("%d\n", query(x, y));
}
return 0;
}