①先用dfs序或欧拉序把树转成区间,因为dfs序是前序遍历,根节点总在前面,而子树结点数,可以dfs时回溯求得,因此dfs序可以轻松维护子树的信息。
②然后用树状数组维护每个点到根的距离(通过前缀和求),这题用树状数组因为只涉及区间加减的修改,复杂度常数比线段树小。(但树状数组局限性大,不支持乘除等复杂的修改)
③最后如何求两点距离?树上差分(通俗理解,前缀和or子树和)的思想,如果是边权,减两个LCA。这里是点权,减lca和一个fa【lca】
几点小Tips:
补习链接:ST表 ST表+欧拉序求LCA
①用ST表求LCA一定用的是欧拉序,但我ST表忘了初始化,还有i、j写反莫名WA了几次…
②如果用dfs序,要记录子树结点数(也可维护欧拉序,记录第二次访问该结点得到的映射序列位置(也就是我代码中的out【】))
③w数组和树状数组,add参数一定要long long!不然莫名WA
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long ll;
const int maxn = 2000006;
struct {
int to, next, w;
} e[maxn << 2];
int head[maxn << 1], edgeNum;
void add(int u, int v) {
e[++edgeNum].to = v;
e[edgeNum].next = head[u];
head[u] = edgeNum;
}
int in[maxn], dfn[maxn], n, cnt, cnt2, m, R, deep[maxn], fa1[maxn];
ll w[maxn];
struct aaa {
ll S[maxn];
void add(int x, ll c) {
while (x <= n) S[x] += c, x += x & -x;
}
ll Sum(int x) {
ll res = 0;
while (x > 0) res += S[x], x -= x & -x;
return res;
}
} st1, st2;
// ll S[3][maxn];
int siz[maxn];
void init() {
// memset(head,-1,4*n+4);
cnt = edgeNum = 0;
}
// ST求LCA
int minl[25][maxn], lg[maxn];
int tmp;
inline void S_table() {
for (int i = 1; i <= cnt; ++i) lg[i] = (1 << (lg[i - 1] + 1)) <= i ? lg[i - 1] + 1 : lg[i - 1];
for (int j = 1; (1 << j) <= cnt; ++j)
for (register int i = 1; i + (1 << j) - 1 <= cnt; ++i) {
minl[j][i] = deep[minl[j - 1][i]] < deep[minl[j - 1][i + (1 << (j - 1))]]
? minl[j - 1][i]
: minl[j - 1][i + (1 << (j - 1))];
}
}
inline int lca(int l, int r) {
if (l > r)
swap(l, r);
int k = lg[r - l + 1];
// int k = log2((double)(r-l+1));
int mid = r - (1 << k) + 1;
// return min(minl[l][k],minl[mid][k]);
return deep[minl[k][l]] < deep[minl[k][r - (1 << k) + 1]] ? minl[k][l]
: minl[k][r - (1 << k) + 1];
}
ll query(int x) {
if (x == 0)
return 0;
return st1.Sum(dfn[x]) + (deep[x] + 1) * st2.Sum(dfn[x]) + w[x];
}
void dfs(int u) {
in[u] = ++cnt;
w[u] += w[fa1[u]];
deep[u] = deep[fa1[u]] + 1;
minl[0][cnt] = u;
dfn[u] = ++cnt2;
siz[u] = 1;
for (int i = head[u]; i; i = e[i].next) {
int v = e[i].to;
if (v == fa1[u])
continue;
fa1[v] = u;
dfs(v);
// out[u] = ++cnt;
minl[0][++cnt] = u;
siz[u] += siz[v];
}
}
int main() {
scanf("%d%d%d", &n, &m, &R);
for (int i = 1; i <= n; ++i) scanf("%lld", &w[i]);
for (int i = 0, u, v; i < n - 1; ++i) scanf("%d%d", &u, &v), add(u, v), add(v, u);
dfs(R); //等于自己dfs会加多
S_table();
for (int i = 0, p, u, v; i < m; ++i) {
scanf("%d%d%d", &p, &u, &v); // 0单点更值
if (p == 1)
st1.add(dfn[u], v), st1.add(dfn[u] + siz[u], -v);
else if (p == 2) { // 1 值*dep
st1.add(dfn[u], -1ll * v * (deep[u])), st1.add(dfn[u] + siz[u], 1ll * v * (deep[u]));
st2.add(dfn[u], v), st2.add(dfn[u] + siz[u], -v); // 2 个数
} else {
int L = lca(in[u], in[v]);
//区间查询,树上差分
printf("%lld\n", query(u) + query(v) - query(L) - query(fa1[L]));
}
// printf("%lld\n",query(u));//单点查询
}
}