这道题是真的妙,感觉做完这一题,我对LCT的理解又加深了一点。
题解:
操作一
我们发现这一个操作就是把树中某个节点到根节点的路径上的所有节点变成一样的颜色,又因为这棵树上有一个性质,同样颜色的点连接起来一定会是一条链,就可以想到LCT的access函数。所以我们将同种颜色的点看成LCT中同一棵splay上的点。
操作二
因为LCT已经维护了颜色了,所以我们不能再用LCT来维护路径之间的权值了。于是就可以想到树上差分,设 d i s [ x ] dis[x] dis[x]为点x的权值,这个值就等于在LCT上该点到根节点所经过虚边的条数+1,x到y路径的权值就是 d i s [ x ] + d i s [ y ] − 2 ∗ d i s [ l c a ] + 1 dis[x]+dis[y]-2*dis[lca]+1 dis[x]+dis[y]−2∗dis[lca]+1,这条公式不多解释,画一个图试试就出来了。
然后,我们就要考虑如何维护 d i s dis dis数组了。在access中,设当前节点为 x x x,它的父亲为 f a fa fa(不在同一棵splay中),因为我们将 x x x和 f a fa fa中原来的虚边变成了实边,所以x为根节点的子树中所有节点的 d i s dis dis-1。设 s o n son son为 f a fa fa原来的儿子,那么就会因为 s o n son son和 f a fa fa中的实边变成虚边使得以 s o n son son为根节点的子树中所以节点的 d i s dis dis-1。于是又现这些要+1或-1的节点都在同一棵子树里面,就很容易想到树的dfs序,我们按照这些点的dfs序建一颗线段树,每次就只需要修改连续的编号即可。
操作三
最复杂的操作二都搞定了,操作三还难吗,直接在线段树中求最大值不就行了呗。
参考代码
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
inline void read(int &x) {
x = 0; int f = 0; char s = getchar();
while (!isdigit(s)) f |= s=='-', s = getchar();
while ( isdigit(s)) x = x * 10 + s - 48, s = getchar();
x = f ? -x : x;
}
int ss = 0, buf[31];
inline void write(int x) {
do {buf[++ss] = x % 10, x /= 10;} while(x);
while (ss) putchar(buf[ss--]+'0'); puts("");
}
const int N = 1e5 + 6;
int L[N], R[N], rev[N], n, m, tot;//L,R,rev,tot用来维护树的dfs序
int fath[N][21], dep[N];//dep也可以看成最开始的dis
int cnt, ver[N<<1], Head[N], Next[N<<1];
namespace Tree {//线段树
#define lc p << 1
#define rc p << 1 | 1
struct T {
int l, r, mark, mx;
} tr[N<<2];
void pushup(int p) {
if (tr[p].mark) {
tr[lc].mark += tr[p].mark, tr[lc].mx += tr[p].mark;
tr[rc].mark += tr[p].mark, tr[rc].mx += tr[p].mark;
tr[p].mark = 0;
}
}
void build(int p, int l, int r) {
tr[p].l = l, tr[p].r = r;
if (l == r) tr[p].mx = dep[rev[l]];
else {
int mid = (l + r)>>1;
build(lc, l, mid), build(rc, mid + 1, r);
tr[p].mx = max(tr[lc].mx, tr[rc].mx);
}
}
int query(int p, int l, int r) {
if (r < tr[p].l || tr[p].r < l)
return 0;
if (l <= tr[p].l && tr[p].r <= r)
return tr[p].mx;
pushup(p);
int mid = (tr[p].l + tr[p].r)>>1;
return max(query(lc, l, r), query(rc, l, r));
}
void update(int p, int l, int r, int add) {
if (r < tr[p].l || tr[p].r < l) return;
if (l <= tr[p].l && tr[p].r <= r) {
tr[p].mark += add, tr[p].mx += add;
return;
}
int mid = (tr[p].l + tr[p].r)>>1;
pushup(p);
update(lc, l, r, add), update(rc, l, r, add);
tr[p].mx = max(tr[lc].mx, tr[rc].mx);
}
}
using namespace Tree;
namespace Link_Cut_Tree {//LCT
#define ls(x) t[x].son[0]
#define rs(x) t[x].son[1]
struct LCT {
int son[2], fa, val; bool mark;
} t[N<<1];
int isroot(int x) {
return (ls(t[x].fa) != x) && (rs(t[x].fa) != x);
}
void pushmark(int x) {
if (t[x].mark) {
swap(ls(x), rs(x));
t[x].mark = 0;
t[ls(x)].mark ^= 1, t[rs(x)].mark ^= 1;
}
}
void rotate(int x) {
int f = t[x].fa, ff = t[f].fa, qwq = (rs(t[x].fa) == x);
t[x].fa = ff;
if (!isroot(f)) t[ff].son[rs(ff)==f] = x;
t[f].son[qwq] = t[x].son[qwq^1];
if (t[x].son[qwq^1]) t[t[x].son[qwq^1]].fa = f;
t[x].son[qwq^1] = f, t[f].fa = x;
}
int st[N];
void splay(int x) {
int top = 0, now = x; st[++top] = now;
while (!isroot(now)) st[++top] = (now = t[now].fa);
while (top) pushmark(st[top--]);
while (!isroot(x)) {
int f = t[x].fa, ff = t[f].fa;
if (!isroot(f))
((rs(f) == x) ^ (rs(ff) == f)) ? rotate(x) : rotate(f);
rotate(x);
}
}
int findrt(int x) {//这里要特别注意一下,splay最左边的点(编号最小)才是真正的根
while (ls(x)) x = ls(x);
return x;
}
void access(int x) {//这里的access和普通的access稍微有一点差别
int son;
for (int y = 0; x; y = x, x = t[y].fa) {
splay(x);
if (rs(x)) son = findrt(rs(x)), update(1, L[son], R[son], 1);
if (rs(x) = y) son = findrt(y), update(1, L[son], R[son], -1);
}
}
}
using namespace Link_Cut_Tree;
void add(int x, int y) {
cnt++;
ver[cnt] = y;
Next[cnt] = Head[x];
Head[x] = cnt;
}
void dfs(int x, int fa) {
L[x] = ++tot, rev[tot] = x;
t[x].fa = fath[x][0] = fa, dep[x] = dep[fa] + 1;
for (int i = 1; i <= 19; i++)
fath[x][i] = fath[fath[x][i-1]][i-1];
for (int i = Head[x]; i; i = Next[i]) {
int y = ver[i];
if (y == fa) continue;
dfs(y, x);
}
R[x] = tot;
}
int LCA(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
for (int i = 19; i >= 0; i--)
if (dep[fath[x][i]] >= dep[y])
x = fath[x][i];
if (x == y) return x;
for (int i = 19; i >= 0; i--)
if (fath[x][i] != fath[y][i])
x = fath[x][i], y = fath[y][i];
return fath[x][0];
}
int main() {
read(n), read(m);
int opt, x, y, ans;
for (int i = 1; i < n; i++) {
read(x), read(y);
add(x, y), add(y, x);
}
dfs(1, 0); t[1].fa = 0;
build(1, 1, n);
for (int i = 1; i <= m; i++) {
read(opt), read(x);
if (opt == 1) access(x);
else if (opt == 2) {
read(y);
int lca = LCA(x, y);
int ans1 = query(1, L[x], L[x]);
int ans2 = query(1, L[y], L[y]);
int ans3 = query(1, L[lca], L[lca]);
write(ans1 + ans2 - 2 * ans3 + 1);
}
else {
ans = query(1, L[x], R[x]);
write(ans);
}
}
return 0;
}