题目链接:http://www.lydsy.com/JudgeOnline/problem.php?id=1036
题意:略。
题解:树链剖分模版,注意一些细节即可。
#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;
const int M = 3e4 + 10;
struct Edge {
int v , next;
}edge[M << 1];
int head[M] , e;
int top[M];
int fa[M];
int p[M];
int fp[M];
int deep[M];
int num[M];
int son[M];
int pos;
void init() {
memset(head , -1 , sizeof(head));
memset(son , -1 , sizeof(son));
e = 0;
pos = 1;
}
void add(int u , int v) {
edge[e].v = v;
edge[e].next = head[u];
head[u] = e++;
}
void dfs1(int u , int pre , int d) {
deep[u] = d;
fa[u] = pre;
num[u] = 1;
for(int i = head[u] ; i != -1 ; i = edge[i].next) {
int v = edge[i].v;
if(v != pre) {
dfs1(v , u , d + 1);
num[u] += num[v];
if(son[u] == -1 || num[son[u]] < num[v]) {
son[u] = v;
}
}
}
}
void getpos(int u , int sp) {
top[u] = sp;
p[u] = pos++;
fp[p[u]] = u;
if(son[u] == -1) return ;
getpos(son[u] , sp);
for(int i = head[u] ; i != -1 ; i = edge[i].next) {
int v = edge[i].v;
if(v != fa[u] && v != son[u])
getpos(v , v);
}
}
struct TnT {
int l , r , sum , MAX;
}T[M << 2];
int a[M];
void pushup(int i) {
T[i].sum = T[i << 1].sum + T[(i << 1) | 1].sum;
T[i].MAX = max(T[i << 1].MAX , T[(i << 1) | 1].MAX);
}
void build(int l , int r , int i) {
int mid = (l + r) >> 1;
T[i].l = l , T[i].r = r , T[i].MAX = 0 , T[i].sum = 0;
if(l == r) {
T[i].MAX = a[fp[l]];
T[i].sum = a[fp[l]];
return ;
}
build(l , mid , i << 1);
build(mid + 1 , r , (i << 1) | 1);
pushup(i);
}
void updata(int i , int pos , int ad) {
int mid = (T[i].l + T[i].r) >> 1;
if(T[i].l == T[i].r && T[i].l == pos) {
T[i].MAX = ad;
T[i].sum = ad;
return ;
}
if(mid < pos) {
updata((i << 1) | 1 , pos , ad);
}
else {
updata(i << 1 , pos , ad);
}
pushup(i);
}
int queryM(int l , int r , int i) {
int mid = (T[i].l + T[i].r) >> 1;
if(T[i].l == l && T[i].r == r) {
return T[i].MAX;
}
pushup(i);
if(mid < l) {
return queryM(l , r , (i << 1) | 1);
}
else if(mid >= r) {
return queryM(l , r , i << 1);
}
else {
return max(queryM(l , mid , i << 1) , queryM(mid + 1 , r , (i << 1) | 1));
}
}
int queryS(int l , int r , int i) {
int mid = (T[i].l + T[i].r) >> 1;
if(T[i].l == l && T[i].r == r) {
return T[i].sum;
}
pushup(i);
if(mid < l) {
return queryS(l , r , (i << 1) | 1);
}
else if(mid >= r) {
return queryS(l , r , i << 1);
}
else {
return queryS(l , mid , i << 1) + queryS(mid + 1 , r , (i << 1) | 1);
}
}
int findM(int u , int v) {
int f1 = top[u] , f2 = top[v];
int tmp = -30010;
while(f1 != f2) {
if(deep[f1] < deep[f2]) {
swap(f1 , f2);
swap(u , v);
}
tmp = max(tmp , queryM(p[f1] , p[u] , 1));
u = fa[f1] , f1 = top[u];
}
if(deep[u] > deep[v]) swap(u , v);
return max(tmp , queryM(p[u] , p[v] , 1));
}
int findS(int u , int v) {
int f1 = top[u] , f2 = top[v];
int tmp = 0;
while(f1 != f2) {
if(deep[f1] < deep[f2]) {
swap(f1 , f2);
swap(u , v);
}
tmp += queryS(p[f1] , p[u] , 1);
u = fa[f1] , f1 = top[u];
}
if(deep[u] > deep[v]) swap(u , v);
return tmp + queryS(p[u] , p[v] , 1);
}
int main() {
int n , u , v , m;
scanf("%d" , &n);
init();
for(int i = 0 ; i < n - 1 ; i++) {
scanf("%d%d" , &u , &v);
add(u , v);
add(v , u);
}
for(int i = 1 ; i <= n ; i++) {
scanf("%d" , &a[i]);
}
dfs1(1 , 0 , 0);
getpos(1 , 1);
build(1 , pos , 1);
scanf("%d" , &m);
char cp[10];
while(m--) {
scanf("%s" , cp);
scanf("%d%d" , &u , &v);
if(cp[0] == 'Q') {
if(cp[1] == 'M') {
printf("%d\n" , findM(u , v));
}
else {
printf("%d\n" , findS(u , v));
}
}
else {
updata(1 , p[u] , v);
}
}
return 0;
}