#1749 树的统计
题面
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
输入
输入文件的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来一行n个整数,第i个整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出
输出
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
样例输入
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
样例输出
4
1
2
2
10
6
5
6
5
16
SOL
树链剖分模板题。
引用一本通提高篇的解释:
f[x]:x的父亲
dep[x]:x的深度
siz[x]:x的子树大小
son[x]:x的重儿子
top[x]:x所在重路径中深度最小的节点
seg[x]:线段树的第x个位置对应的树的节点
rev[x]:rev[seg[x]]=x;
具体做法详见代码。
代码:
#include<bits/stdc++.h>
#define inf 0x3f3f3f3f
#define N 120025
#define M 30005
using namespace std;
inline int rd(){
int data=0,w=1;static char ch=0;ch=getchar();
while(ch!='-'&&(!isdigit(ch)))ch=getchar();
if(ch=='-')w=-1,ch=getchar();
while(isdigit(ch))data=(data<<1)+(data<<3)+ch-'0',ch=getchar();
return data*w;
}
int cnt,first[M];
struct node{int v,nxt;}e[M<<1];
inline void add(int u,int v){e[++cnt].v=v;e[cnt].nxt=first[u];first[u]=cnt;}
int n,q,w[M],summ,maxx;
int f[N],dep[N],siz[N],son[N],top[N],seg[N],rev[N];
void dfs1(int u,int fa){
siz[u]=1;
f[u]=fa;
dep[u]=dep[fa]+1;
for(int register i=first[u];i;i=e[i].nxt){
int register v=e[i].v;
if(v!=fa){
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
}
void dfs2(int u){
if(son[u]){
seg[son[u]]=++seg[0];
rev[seg[0]]=son[u];
top[son[u]]=top[u];
dfs2(son[u]);
}
for(int register i=first[u];i;i=e[i].nxt){
int v=e[i].v;
if(!top[v]){
seg[v]=++seg[0];
rev[seg[0]]=v;
top[v]=v;
dfs2(v);
}
}
}
#define lc (p<<1)
#define rc (p<<1|1)
struct segment_tree{int sum,mx;}t[N];
inline void pushup(int p){
t[p].sum=t[lc].sum+t[rc].sum;
t[p].mx=max(t[lc].mx,t[rc].mx);
}
void build(int p,int l,int r){//cerr<<"F**k\n";
if(l==r){t[p].sum=t[p].mx=w[rev[l]];return;}
int mid=l+r>>1;
build(lc,l,mid);
build(rc,mid+1,r);
pushup(p);
}
void change(int p,int l,int r,int v,int pos){
if(pos>r||pos<l)return;
if(l==r&&r==pos){t[p].sum=t[p].mx=v;return;}
int mid=l+r>>1;
if(pos<=mid)change(lc,l,mid,v,pos);
if(pos>mid)change(rc,mid+1,r,v,pos);
pushup(p);
}
void query(int p,int l,int r,int ql,int qr){
if(ql>r||qr<l)return;
if(ql<=l&&qr>=r){
summ+=t[p].sum;
maxx=max(maxx,t[p].mx);
return;
}
int mid=l+r>>1;
if(ql<=mid)query(lc,l,mid,ql,qr);
if(qr>mid)query(rc,mid+1,r,ql,qr);
}
void ret(int x,int y){
int fx=top[x],fy=top[y];
while(fx!=fy){
if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy);
query(1,1,seg[0],seg[fx],seg[x]);
x=f[fx];fx=top[x];
}
if(dep[x]>dep[y])swap(x,y);
query(1,1,seg[0],seg[x],seg[y]);
}
char ins[10];
int main(){
n=rd();
for(int register i=1;i<n;i++){
int x=rd(),y=rd();
add(x,y);add(y,x);
}
for(int register i=1;i<=n;i++)w[i]=rd();
dfs1(1,0);
seg[0]=seg[1]=top[1]=rev[1]=1;
dfs2(1);
build(1,1,seg[0]);
q=rd();
while(q--){
scanf("%s",ins);
int register u=rd(),v=rd();
if(ins[1]=='H'){
change(1,1,seg[0],v,seg[u]);
}
if(ins[0]=='Q'){
maxx=-inf;summ=0;ret(u,v);
if(ins[1]=='M')printf("%d\n",maxx);
if(ins[1]=='S')printf("%d\n",summ);
}
}
return 0;
}