https://www.luogu.com.cn/problem/P4842
发现题目就是要维护一个LCT,然后我们只要把pushup写成功了就行。
那我们现在就不管LCT了,就单纯想用一棵二叉查找树怎么维护。分母是好搞的,分子我们要想点办法。
考虑右子树对左子树的贡献:我们处理出一个 $L[k]$,表示子树中每个值乘以其左端点的可选数量(即它从左数的位置)之和,把左儿子的 $L$ 再乘上右端点的可选数量,即右子树大小加一 $w_r+1$,就成功了;右儿子同理用 $R[k]$ 乘以 $w_l+1$。
那么pushup就很好写了。现在就是整棵子树加 $x$ 的问题,这个东西看起来很复杂,但其实只需要把转移系数仔细算算就出来了:若子树大小为 $w$,则 $ans$ 增加 $x\cdot\frac{w(w+1)(w+2)}{6}$,$L$ 和 $R$ 各增加 $x\cdot\frac{w(w+1)}{2}$。
一个要注意的地方是,整体reverse的时候要交换 $L[k]$ 和 $R[k]$。
//5k
#include<bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stdout, ##__VA_ARGS__)
#define debag(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define debug(...) void(0)
#define debag(...) void(0)
#endif
#define int long long
// Reads one decimal integer from stdin. Any non-digit characters before the
// number are skipped; a '-' among them makes the result negative.
// Returns 0 if EOF is reached before any digit. The original looped forever
// in that case, because it kept re-reading EOF in the skip loop.
inline int read(){
    int x = 0, f = 1;
    int ch = getchar();   // int, not char, so EOF stays distinguishable
    while(ch < '0' || ch > '9') {
        if(ch == EOF) return 0;
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}
// Z(x): square of x. The whole expansion is parenthesised so that the macro
// composes safely, e.g. 36 / Z(3) == 4 (the old `(x)*(x)` gave 36).
#define Z(x) ((x)*(x))
// Short aliases for common STL member names.
#define pb push_back
#define fi first
#define se second
// Sum of the first n triangular numbers, sum_{i=1}^{n} i*(i+1)/2,
// computed as (1 + ... + n + 1^2 + ... + n^2) / 2.
// Not called anywhere in the visible code.
int F1(int n) {
    const int sumOfI = n * (n + 1) / 2;
    const int sumOfSquares = n * (n + 1) * (2 * n + 1) / 6;
    return (sumOfI + sumOfSquares) / 2;
}
// Tetrahedral number n(n+1)(n+2)/6. This equals the total length of all
// contiguous sub-segments of a length-n sequence; LCT::Add uses it to update
// ans when every value in a subtree of size n increases.
int F3(int n) {
    const int product = n * (n + 1) * (n + 2);
    return product / 6;
}
// Triangular number n(n+1)/2 = 1 + 2 + ... + n. This is also the number of
// contiguous sub-segments of a length-n sequence.
int F2(int n) {
    const int product = n * (n + 1);
    return product / 2;
}
// Commented-out switches; their purpose is not evident from this file.
//#define M
//#define mo
// Capacity of every per-node array; node ids 1..n index them, so N must exceed n.
#define N 50010
// Globals: n = node count, m = operation count; i and k are loop/temporary
// variables used by main(). j and T are not used by the visible code
// (T only appears in commented-out code).
int n, m, i, j, k, T;
namespace LCT {
// Children of node x inside its splay (auxiliary) tree.
#define ls(x) (son[x][0])
#define rs(x) (son[x][1])
int i, j, k;
// Per-node state. Index 0 is the null child: push_up reads s/w/L/R/ans of
// index 0 for missing children, so those entries must remain zero.
//   a[k]    value stored at node k
//   w[k]    number of nodes in k's splay subtree (a contiguous path piece)
//   s[k]    sum of the values in k's splay subtree
//   L[k]    sum of value * (1-based position counted from the left end)
//   R[k]    sum of value * (1-based position counted from the right end)
//   ans[k]  sum, over all contiguous sub-segments, of the segment's sum
//   tag[k]  pending "+x" for the children (already applied to k itself)
//   rev[k]  pending reversal of k's subtree, NOT yet applied to k itself
//   son/fa  splay links; fa of a splay root is its path-parent (or 0)
int a[N], s[N], L[N], R[N], ans[N], tag[N], son[N][2], w[N];
int rev[N], fa[N];

// Adds x to every value in k's subtree. k's aggregates are updated at once;
// x is remembered in tag[k] so it can later be passed on to the children.
void Add(int k, int x) {
    a[k] += x;
    s[k] += x * w[k];
    tag[k] += x;
    ans[k] += x * F3(w[k]);   // total length of all sub-segments of w[k] nodes
    L[k] += x * F2(w[k]);     // 1 + 2 + ... + w[k]
    R[k] += x * F2(w[k]);
}

// 1 if x is the right child of fa[x], 0 otherwise.
int zh(int x) {
    return rs(fa[x]) == x;
}

// Applies k's pending lazy operations to its children.
// rev: swap k's children and k's L/R, then flag both children.
// tag: Add() the pending delta to each child.
void push_down(int k) {
    if(rev[k]) {
        swap(ls(k), rs(k));
        swap(L[k], R[k]);
        if(ls(k)) rev[ls(k)] ^= 1;
        if(rs(k)) rev[rs(k)] ^= 1;
        rev[k] = 0;
    }
    if(tag[k]) {
        if(ls(k)) Add(ls(k), tag[k]);
        if(rs(k)) Add(rs(k), tag[k]);
        tag[k] = 0;
    }
}

// Resolves k's and its children's pending lazies, then recomputes k's
// aggregates from its children. The children are read only AFTER
// push_down(k): if rev[k] was set, push_down swaps them, and reading them
// earlier combines the halves in the wrong order (wrong ans/L/R).
// makeRoot triggers this by calling push_up right after toggling rev.
void push_up(int k) {
    if(!k) return;
    push_down(k);
    int l = ls(k), r = rs(k);
    if(l) push_down(l);
    if(r) push_down(r);
    s[k] = s[l] + s[r] + a[k];
    w[k] = w[l] + w[r] + 1;
    // Sub-segments inside one child, plus those passing through k:
    // a[k] lies in (w[l]+1)*(w[r]+1) of them. A left value at position p
    // lies in p*(w[r]+1) of them; a right value at reverse position q lies
    // in q*(w[l]+1) of them.
    ans[k] = ans[l] + ans[r] + a[k] * (w[l] + 1) * (w[r] + 1)
           + L[l] * (w[r] + 1) + R[r] * (w[l] + 1);
    L[k] = L[l] + L[r] + (a[k] + s[r]) * (w[l] + 1);
    R[k] = R[r] + R[l] + (a[k] + s[l]) * (w[r] + 1);
}

// True if x is the root of its splay tree: it has no parent, or fa[x] is a
// path-parent pointer (x is neither child of fa[x]).
bool isRoot(int x) {
    if(!fa[x]) return true;
    return ls(fa[x]) != x && rs(fa[x]) != x;
}

// Rotates x above its parent y while keeping the in-order sequence.
// z is y's former parent, which may be a real splay parent or a path-parent.
void Rotate(int x) {
    int y = fa[x], z = fa[y];
    int k = zh(x), mid = son[x][k ^ 1];        // subtree handed from x to y
    if(z && !isRoot(y)) son[z][zh(y)] = x;      // rewire z only if y is its real child
    son[y][k] = mid;
    if(mid) fa[mid] = y;
    fa[x] = z; fa[y] = x; son[x][k ^ 1] = y;
    push_up(mid); push_up(y); push_up(x); push_up(z);
}

// Splays x to the root of its splay tree. Pending lazies on the
// root-to-x path are resolved top-down first, so rotations see the real shape.
void Splay(int x) {
    static vector<int> path;   // reused buffer: no allocation per call
    path.clear();
    path.push_back(x);
    for(int y = x; !isRoot(y); y = fa[y]) path.push_back(fa[y]);
    for(auto it = path.rbegin(); it != path.rend(); ++it) push_up(*it);
    while(!isRoot(x)) {
        int y = fa[x];
        if(!isRoot(y)) Rotate(zh(x) == zh(y) ? y : x);
        Rotate(x);
    }
}

// Puts exactly the path from the represented root to x into one splay tree,
// with x as the deepest node on it.
void access(int x) {
    for(int y = 0; x; y = x, x = fa[x]) {
        Splay(x); rs(x) = y; push_up(y); push_up(x);
    }
}

// Re-roots x's tree at x by reversing the root-to-x path.
void makeRoot(int x) {
    access(x); Splay(x); rev[x] ^= 1; push_up(x);
}

// Returns the root of x's represented tree (the leftmost node of the
// root-to-x path) and splays it.
int findRoot(int x) {
    access(x); Splay(x);
    for(;;) {
        push_up(x);            // resolve reversals before choosing the left child
        if(!ls(x)) break;
        x = ls(x);
    }
    Splay(x);
    return x;
}

// Adds edge (x, y). The caller must ensure x and y are in different trees.
void Link(int x, int y) {
    makeRoot(x); fa[x] = y;
}

// Removes edge (x, y). The caller must ensure the edge exists (see find_edge).
void Cut(int x, int y) {
    makeRoot(x); access(y); Splay(y);
    fa[x] = son[y][0] = 0; push_up(y); push_up(x);
}

// True if x and y are in the same tree. Side effect: x becomes the root.
bool check(int x, int y) {
    makeRoot(x);
    return findRoot(y) == x;
}

// True if (x, y) is a tree edge. After makeRoot(x), access(y) and Splay(y),
// y's splay tree holds the path from y's tree root to y, with y last.
// The edge exists iff that path is exactly [x, y]. Checking w[y] == 2 alone
// is wrong: if x and y are in different trees, y's path from its own root
// can also have two nodes.
bool find_edge(int x, int y) {
    makeRoot(x); access(y); Splay(y);
    return w[y] == 2 && ls(y) == x;
}

// For the x-y path, returns {sum over all sub-paths of the sub-path sum,
// number of nodes on the path}. The caller must ensure x and y are connected.
pair<int, int> Qans(int x, int y) {
    makeRoot(x); access(y); Splay(y);
    return {ans[y], w[y]};
}

// Debug dump of every node; compiled only with LOCAL defined.
void print() {
#ifdef LOCAL
    for(int i = 0; i <= n; ++i)
        debug("> %lld %lld(%lld) [%lld %lld] %lld %lld %lld %lld %lld\n",
              i, fa[i], (int)isRoot(i), ls(i), rs(i), w[i], s[i], L[i], R[i], ans[i]);
    debug("-------------\n");
#endif
}
}
// Reads n, m, the n initial node values and the n-1 initial edges, then
// processes m operations:
//   1 u v    cut edge (u, v) if find_edge reports it
//   2 u v    link u and v if they are in different trees
//   3 u v d  add d to every node on the u-v path if u and v are connected
//   4 u v    print (sum over sub-paths of their sums) / (number of sub-paths)
//            for the u-v path as a reduced fraction, or -1 if disconnected
signed main()
{
#ifdef LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
#endif
    int op, u, v;
    n = read();
    m = read();
    // Every node starts as its own single-node splay tree.
    for(i = 1; i <= n; ++i) LCT :: w[i] = 1;
    for(i = 1; i <= n; ++i) {
        k = read();
        LCT :: Add(i, k);
    }
    for(i = 1; i < n; ++i) {
        u = read();
        v = read();
        LCT :: Link(u, v);
    }
    for(i = 1; i <= m; ++i) {
        op = read();
        u = read();
        v = read();
        switch(op) {
        case 1:
            if(u == v) continue;   // early exits also skip the debug dump
            if(LCT :: find_edge(u, v)) LCT :: Cut(u, v);
            break;
        case 2:
            if(u == v) continue;
            if(!LCT :: check(u, v)) LCT :: Link(u, v);
            break;
        case 3:
            k = read();   // the delta is consumed even when u, v are disconnected
            if(LCT :: check(u, v)) {
                // After re-rooting at u and accessing v, splaying u makes
                // u's splay subtree the whole u-v path.
                LCT :: makeRoot(u);
                LCT :: access(v);
                LCT :: Splay(u);
                LCT :: Add(u, k);
            }
            break;
        case 4: {
            if(!LCT :: check(u, v)) { printf("-1\n"); continue; }
            const pair<int, int> res = LCT :: Qans(u, v);
            debug("[%lld %lld] %lld\n", res.first, res.second, F2(res.second));
            const int subPaths = F2(res.second);   // w nodes -> w(w+1)/2 sub-paths
            k = __gcd(res.first, subPaths);
            printf("%lld/%lld\n", res.first / k, subPaths / k);
            break;
        }
        default:
            break;
        }
        LCT :: print();
    }
    return 0;
}