// Problem: https://cn.vjudge.net/problem/HYSBZ-1036
// Technique: heavy-light decomposition (树链剖分) + segment tree
#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <cstring>
#include <vector>
using namespace std;
// Child indices of segment-tree node o (1-based heap layout).
#define lson o<<1
#define rson o<<1|1
// Declares m, the midpoint of the current node's interval [l, r].
#define MID int m = (l+r)/2
const int inf = 0x3f3f3f3f;      // large sentinel; -inf seeds the path maximum
const int maxn = 300000 + 100;   // capacity bound on the number of vertices
int n;                           // number of vertices (1..n)
std::vector<int> edge[maxn];     // undirected adjacency lists
int data[maxn];                  // weight of each vertex, indexed by vertex id
int cnt;                         // running DFS-position counter used by dfss
int fa[maxn],deep[maxn];         // parent (-1 for the root) and depth of each vertex
// Subtree size, heavy child (-1 if none), chain head, and DFS position of each vertex.
int siz[maxn],son[maxn],top[maxn],tid[maxn];
int id_data[maxn];               // vertex weights laid out by DFS position
// Segment-tree node: sum and maximum over its interval.
struct Info
{
    int sum,MAX;
}tree[maxn*4];  // a 1-based segment tree over at most maxn leaves uses fewer than 4*maxn nodes
void build(int o,int l,int r)
{
if(l == r)
{
tree[o].sum = id_data[l];
tree[o].MAX = id_data[l];
return ;
}
MID;
build(lson,l,m);
build(rson,m+1,r);
tree[o].sum = tree[lson].sum + tree[rson].sum;
tree[o].MAX = max(tree[lson].MAX,tree[rson].MAX);
}
// Point assignment: sets position x to y inside node o (covering [l, r]),
// then recomputes sum and maximum on the way back up. Positions outside
// [l, r] are ignored.
void updata(int o,int l, int r, int x, int y)
{
    if(x < l || x > r) return ;
    if(l == r)
    {
        tree[o].MAX = tree[o].sum = y;
        return ;
    }
    MID;
    // Only the child whose interval contains x can change.
    if(x <= m) updata(lson,l,m,x,y);
    else updata(rson,m+1,r,x,y);
    tree[o].sum = tree[lson].sum + tree[rson].sum;
    tree[o].MAX = max(tree[lson].MAX,tree[rson].MAX);
}
// Sum of the stored weights at positions [ul, ur] ∩ [l, r], where node o
// covers [l, r]. Returns 0 when the ranges do not overlap.
int query1(int o, int l, int r, int ul, int ur)
{
    const bool disjoint = ur < l || r < ul;
    if(disjoint) return 0;
    const bool covered = l >= ul && r <= ur;
    if(covered) return tree[o].sum;
    MID;
    const int fromLeft = query1(lson,l,m,ul,ur);
    const int fromRight = query1(rson,m+1,r,ul,ur);
    return fromLeft + fromRight;
}
int query2(int o, int l, int r, int ul, int ur)
{
if(ul > r || ur < l) return -333333;
if(ul <= l && r <= ur)
{
return tree[o].MAX;
}
MID;
return max(query2(lson,l,m,ul,ur) , query2(rson,m+1,r,ul,ur));
}
// Sum of vertex weights on the tree path between x and y. The endpoint whose
// chain head is deeper jumps up its heavy chain until both share a chain.
int Query1(int x,int y)
{
    int total = 0;
    while(top[x] != top[y])
    {
        if(deep[top[x]] < deep[top[y]]) swap(x,y);
        // Add the segment from x's chain head down to x, then step above it.
        total += query1(1,1,n,tid[top[x]],tid[x]);
        x = fa[top[x]];
    }
    // Same chain: the shallower vertex has the smaller DFS position.
    if(deep[x] > deep[y]) swap(x,y);
    total += query1(1,1,n,tid[x],tid[y]);
    return total;
}
int Query2(int x,int y)
{
if(x == y) return query2(1,1,n,tid[x],tid[x]);
int tx = top[x],ty = top[y];
int ans = -inf;
while(tx != ty)
{
if(deep[tx] < deep[ty]) swap(x,y),swap(tx,ty);
ans = max(ans,query2(1,1,n,tid[tx],tid[x]));
x = fa[tx],tx = top[x];
}
if(deep[x] > deep[y]) swap(x,y);
ans = max(ans,query2(1,1,n,tid[x],tid[y]));
return ans;
}
void dffs(int u,int f,int d)
{
fa[u] = f,deep[u] = d;
siz[u] = 1,son[u] = -1;
for(int i = 0; i < edge[u].size(); i++)
{
int v = edge[u][i];
if(v != f)
{
dffs(v,u,d+1);
siz[u] += siz[v];
if(son[u] == -1||siz[son[u]] < siz[v])
{
son[u] = v;
}
}
}
}
void dfss(int u,int t)
{
tid[u] = ++cnt;
top[u] = t;
id_data[cnt] = data[u];
if(son[u] != -1)
{
dfss(son[u],t);
}
for(int i = 0;i<edge[u].size();i++)
{
int v = edge[u][i];
if(son[u] != v && fa[u] != v) dfss(v,v);
}
}
int main()
{
scanf("%d", &n);
for(int i=1;i<=n;i++) edge[i].clear();
for(int i = 1; i < n; i++)
{
int u,v;
scanf("%d %d", &u, &v);
edge[u].push_back(v);
edge[v].push_back(u);
}
for(int i = 1;i <= n;i++)
{
scanf("%d",&data[i]);
}
cnt = 0;
dffs(1,-1,0);
dfss(1,1);
build(1,1,n);
int t;
scanf("%d", &t);
char str[10];
while(t--)
{
scanf("%s",str);
int a, b, c;
if(strcmp(str,"CHANGE") == 0)
{
scanf("%d %d",&a, &b);
updata(1,1,n,tid[a],b);
}
else
{
scanf("%d %d",&a, &b);
if(strcmp(str,"QMAX") == 0)
{
printf("%d\n",Query2(a,b));
}
else
{
printf("%d\n",Query1(a,b));
}
}
}
}