一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 III. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。 对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
题解:可以用树链剖分来做,将其分解成线段,用线段树来维护。输入节点值的时候应该注意,数组对应得下标应该以新的编号为准,然后再建线段树。
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <string>
#include <math.h>
#include <stdlib.h>
#include <time.h>
#include <bitset>
#define INF 0x3f3f3f3f
#define eps 1e-6
#define PI 3.1415926
#define mod 1000000009
#define base 2333
using namespace std;
typedef long long LL;
const int maxn = 3e4 + 10;
const int maxx = 1e3 + 10;
inline void splay(int &v) {
v=0;char c=0;int p=1;
while(c<'0' || c >'9'){if(c=='-')p=-1;c=getchar();}
while(c>='0' && c<='9'){v=(v<<3)+(v<<1)+c-'0';c=getchar();}
v*=p;
}
int n, a, b, q, len, cnt, x, y, head[maxn], dep[maxn], siz[maxn];
int fa[maxn], son[maxn], id[maxn], val[maxn], top[maxn];
struct node {
int to, next;
} e[maxn<<1];
struct Tree {
int l, r, m;
int mx, sum;
} tr[maxn<<2];
void add(int from, int to) {
e[len].to = to;
e[len].next = head[from];
head[from] = len++;
}
void dfs1(int u, int ff, int deep) {
dep[u] = deep, fa[u] = ff;
son[u] = 0, siz[u] = 1;
for(int i = head[u]; i != -1; i = e[i].next) {
int v = e[i].to;
if(v == ff) continue;
dfs1(v, u, deep+1);
siz[u] += siz[v];
if(siz[son[u]] < siz[v])
son[u] = v;
}
}
void dfs2(int u, int tp) {
top[u] = tp;
id[u] = ++cnt;
if(son[u]) dfs2(son[u], tp);
for(int i = head[u]; i != -1; i = e[i].next) {
int v = e[i].to;
if(v == fa[u] || v == son[u]) continue;
dfs2(v, v);
}
}
void PushUp(int id) {
tr[id].mx = max(tr[id<<1].mx, tr[id<<1|1].mx);
tr[id].sum = tr[id<<1].sum+tr[id<<1|1].sum;
}
void build(int id, int l, int r) {
tr[id].l = l, tr[id].r = r;
tr[id].m = (l+r)>>1;
if(l == r) tr[id].mx = tr[id].sum = val[l];
else {
build(id<<1, l, tr[id].m);
build(id<<1|1, tr[id].m+1, r);
PushUp(id);
}
}
void update(int id, int l, int r, int v) {
if(l <= tr[id].l && r >= tr[id].r)
tr[id].mx = tr[id].sum = v;
else {
if(l <= tr[id].m) update(id<<1, l, r, v);
if(r > tr[id].m) update(id<<1|1, l, r, v);
PushUp(id);
}
}
int Query_Max(int id, int l, int r) {
if(l <= tr[id].l && r >= tr[id].r)
return tr[id].mx;
else {
int mx1 = -INF, mx2 = -INF;
if(l <= tr[id].m) mx1 = Query_Max(id<<1, l, r);
if(r > tr[id].m) mx2 = Query_Max(id<<1|1, l, r);
return max(mx1, mx2);
}
}
int Query_Sum(int id, int l, int r) {
if(l <= tr[id].l && r >= tr[id].r)
return tr[id].sum;
else {
int sum1 = 0, sum2 = 0;
if(l <= tr[id].m) sum1 = Query_Sum(id<<1, l, r);
if(r > tr[id].m) sum2 = Query_Sum(id<<1|1, l, r);
return sum1+sum2;
}
}
int Find(int u, int v, int op) {
int tp1 = top[u], tp2 = top[v], ans = 0;
if(op) ans = -INF;
while(tp1 != tp2) {
if(dep[tp1] < dep[tp2]) {
swap(tp1, tp2);
swap(u, v);
}
if(op) ans = max(ans, Query_Max(1, id[tp1], id[u]));
else ans += Query_Sum(1, id[tp1], id[u]);
u = fa[tp1], tp1 = top[u];
}
if(dep[u] > dep[v]) swap(u, v);
if(op) ans = max(ans, Query_Max(1, id[u], id[v]));
else ans += Query_Sum(1, id[u], id[v]);
return ans;
}
void solve() {
splay(n);
memset(head, -1, sizeof(head));
len = 0, cnt = 0;
for(int i = 1; i < n; i++) {
splay(a), splay(b);
add(a, b), add(b, a);
}
dfs1(1, 0, 1);
dfs2(1, 1);
for(int i = 1; i <= n; i++)
splay(val[id[i]]);
build(1, 1, n);
splay(q);
char str[10];
while(q--) {
scanf("%s", str);
splay(x), splay(y);
if(str[0] == 'C')
update(1, id[x], id[x], y);
else {
int ans = 0;
if(str[1] == 'M') ans = Find(x, y, 1);
if(str[1] == 'S') ans = Find(x, y, 0);
printf("%d\n", ans);
}
}
}
int main() {
//srand(time(NULL));
//freopen("kingdom.in","r",stdin);
//freopen("kingdom.out","w",stdout);
solve();
}