题目链接
题目要求在一棵树上支持以下三种操作:
I. CHANGE u t:把结点 u 的权值改为 t
II. QMAX u v:询问从点 u 到点 v 的路径上的结点的最大权值
III. QSUM u v:询问从点 u 到点 v 的路径上的结点的权值和
这三种操作都是树链剖分配合线段树的典型用法;我一开始没注意权值范围,WA 了一次。
下面是 AC 代码:
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
// 64-bit signed integer used for node weights, path sums and path maxima.
#define ll long long
using namespace std;

// Length of every per-node table; nodes are numbered starting from 1.
const int N = 1e5+5;
// Large 64-bit sentinel; its negation is used as the identity for max queries.
const ll lnf = 0x3f3f3f3f3f3f3f3f;

// Adjacency lists stored as linked edge records:
//   he[x]  - id of the first edge leaving node x (0 means "no edge"),
//   ne[e]  - id of the next edge in the same list (0 terminates the list),
//   ver[e] - node that edge e leads to.
// ne and ver are both indexed by edge id. Every undirected edge is stored as
// two directed edges, so both need twice as many slots as there are nodes;
// ver must therefore be as long as ne, otherwise add() writes past its end.
int he[N<<1], ne[N<<1], ver[N<<1];

// Per-node data filled by dfs1/dfs2:
//   dep - depth (the root gets depth 1), siz - subtree size, fa - parent,
//   son - child with the largest subtree (0 for a leaf),
//   top - first node of the chain the node belongs to.
// id is not referenced by any code visible in this file.
int dep[N], siz[N], fa[N], id[N], son[N], top[N];
// dfn[u] is the position dfs2 assigned to node u (1-based).
int dfn[N];
// su[u] is the weight of node u as read from the input;
// w[pos] is the weight of the node whose dfn is pos (segment-tree leaf order).
ll su[N], w[N];
// Id of the most recently allocated edge; starts at 1 so the first real edge
// gets id 2 and id 0 stays free as the list terminator.
int tot = 1;
// One segment-tree node: the closed position interval [l, r] it covers and
// the maximum and the sum of the weights stored at those positions.
struct Node
{
    int l;
    int r;
    ll mx;
    ll sum;
};
// Segment-tree storage; the node with index p has children 2p and 2p+1.
Node tr[N<<2];
// Inserts the directed edge x -> y at the head of x's adjacency list.
void add(int x, int y)
{
    const int edge = ++tot;   // allocate the next edge id
    ver[edge] = y;
    ne[edge] = he[x];         // the previous head becomes the successor
    he[x] = edge;
}
void dfs1(int u, int f)
{
fa[u] = f;
dep[u] = dep[f]+1;
siz[u] = 1;
int mx = -1;
for (int i = he[u]; i; i = ne[i])
{
int y = ver[i];
if (y == f) continue;
dfs1(y, u);
siz[u] += siz[y];
if (siz[y] > mx)
{
mx = siz[y];
son[u] = y;
}
}
}
int cnt;
void dfs2(int u, int t)
{
dfn[u] = ++cnt;
top[u] = t;
w[cnt] = su[u];
if (!son[u])
return;
dfs2(son[u], t);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == fa[u] || v == son[u])
continue;
dfs2(v, v);
}
}
// Recomputes the maximum and the sum of node p from its two children.
void pushup(int p)
{
    const int lc = p << 1;
    const int rc = lc | 1;
    tr[p].mx = max(tr[lc].mx, tr[rc].mx);
    tr[p].sum = tr[lc].sum + tr[rc].sum;
}
// Builds the segment tree for positions [l, r] rooted at node p, taking the
// leaf values from w.
void build(int p, int l, int r)
{
    tr[p].l = l;
    tr[p].r = r;
    if (l != r)
    {
        // Inner node: build both halves, then combine them.
        const int mid = (l + r) >> 1;
        const int lc = p << 1;
        build(lc, l, mid);
        build(lc | 1, mid + 1, r);
        pushup(p);
        return;
    }
    // Leaf: both aggregates are just the single stored weight.
    tr[p].sum = w[l];
    tr[p].mx = tr[p].sum;
}
// Point update: inside the subtree rooted at node p, sets the leaf reached
// for position k to x and refreshes the aggregates on the way back up to p.
void change(int p, int k, ll x)
{
    // Walk down to the leaf, choosing the half that contains k at each step.
    int node = p;
    while (tr[node].l != tr[node].r)
    {
        const int mid = (tr[node].l + tr[node].r) >> 1;
        node = (k <= mid) ? (node << 1) : (node << 1 | 1);
    }
    tr[node].sum = x;
    tr[node].mx = x;
    // Walk back up: the parent of node i is i >> 1; stop once p is refreshed.
    while (node != p)
    {
        node >>= 1;
        pushup(node);
    }
}
// Returns the sum of the weights at positions [l, r] that lie inside the
// interval covered by node p.
ll ask_s(int p, int l, int r)
{
    const int lo = tr[p].l;
    const int hi = tr[p].r;
    if (lo >= l && hi <= r)
        return tr[p].sum;   // node fully inside the query range
    const int mid = (lo + hi) >> 1;
    const ll leftPart = (mid >= l) ? ask_s(p << 1, l, r) : 0;
    const ll rightPart = (mid < r) ? ask_s(p << 1 | 1, l, r) : 0;
    return leftPart + rightPart;
}
// Returns the maximum weight at positions [l, r] that lie inside the interval
// covered by node p; -lnf is the value for "nothing covered".
ll ask_m(int p, int l, int r)
{
    const int lo = tr[p].l;
    const int hi = tr[p].r;
    if (lo >= l && hi <= r)
        return tr[p].mx;    // node fully inside the query range
    const int mid = (lo + hi) >> 1;
    ll best = -lnf;
    if (mid >= l)
    {
        const ll cand = ask_m(p << 1, l, r);
        if (cand > best)
            best = cand;
    }
    if (mid < r)
    {
        const ll cand = ask_m(p << 1 | 1, l, r);
        if (cand > best)
            best = cand;
    }
    return best;
}
// Returns the maximum node weight on the tree path between x and y.
ll mask_m(int x, int y)
{
    ll best = -lnf;
    // While the endpoints sit on different chains, take the one whose chain
    // top is deeper, query its stretch [top[x], x] and jump above that chain.
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[y]] > dep[top[x]])
            swap(x, y);
        best = max(best, ask_m(1, dfn[top[x]], dfn[x]));
    }
    // Same chain now: query from the shallower endpoint to the deeper one.
    const bool xDeeper = dep[x] > dep[y];
    const int from = xDeeper ? dfn[y] : dfn[x];
    const int to = xDeeper ? dfn[x] : dfn[y];
    return max(best, ask_m(1, from, to));
}
// Returns the sum of the node weights on the tree path between x and y.
ll mask_s(int x, int y)
{
    ll total = 0;
    // While the endpoints sit on different chains, take the one whose chain
    // top is deeper, add its stretch [top[x], x] and jump above that chain.
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[y]] > dep[top[x]])
            swap(x, y);
        total += ask_s(1, dfn[top[x]], dfn[x]);
    }
    // Same chain now: add the stretch from the shallower endpoint to the deeper.
    const bool xDeeper = dep[x] > dep[y];
    const int from = xDeeper ? dfn[y] : dfn[x];
    const int to = xDeeper ? dfn[x] : dfn[y];
    return total + ask_s(1, from, to);
}
// Returns true when v is a usable 1-based node index in a tree of n nodes.
static bool valid_node(int v, int n)
{
    return 1 <= v && v <= n;
}

// Reads the tree and answers the queries.
// Input order: n, then n-1 edges "x y", then n node weights, then m, then m
// commands. A command is a word followed by two numbers; the word's second
// character selects the action: 'M' prints the path maximum, 'S' prints the
// path sum, anything else sets the weight of the first number's node to the
// second number.
// Returns 0 on success and 1 when the input is truncated, malformed or refers
// to a node outside 1..n.
int main()
{
    // Per-node tables are indexed 1..n, so n must stay below their length.
    const size_t nodeCap = sizeof(dfn) / sizeof(dfn[0]);
    int n;
    if (scanf("%d", &n) != 1 || n < 1 || static_cast<size_t>(n) >= nodeCap)
        return 1;
    for (int i = 1; i < n; i++)
    {
        int x, y;
        if (scanf("%d%d", &x, &y) != 2 || !valid_node(x, n) || !valid_node(y, n))
            return 1;
        // Store the undirected edge as two directed ones.
        add(x, y);
        add(y, x);
    }
    for (int i = 1; i <= n; i++)
    {
        if (scanf("%lld", &su[i]) != 1)
            return 1;
    }
    // Node 1 is the root; 0 acts as its (non-existent) parent.
    dfs1(1, 0);
    dfs2(1, 1);
    build(1, 1, n);
    int m;
    if (scanf("%d", &m) != 1)
        return 1;
    while (m-- > 0)
    {
        char op[16];
        // The field width keeps an over-long token from overflowing op.
        if (scanf("%15s", op) != 1)
            return 1;
        if (op[1] == 'M')
        {
            int l, r;
            if (scanf("%d%d", &l, &r) != 2 || !valid_node(l, n) || !valid_node(r, n))
                return 1;
            printf("%lld\n", mask_m(l, r));
        }
        else if (op[1] == 'S')
        {
            int l, r;
            if (scanf("%d%d", &l, &r) != 2 || !valid_node(l, n) || !valid_node(r, n))
                return 1;
            printf("%lld\n", mask_s(l, r));
        }
        else
        {
            int k;
            ll x;
            if (scanf("%d%lld", &k, &x) != 2 || !valid_node(k, n))
                return 1;
            // The segment tree is indexed by dfn position, not by node id.
            change(1, dfn[k], x);
        }
    }
    return 0;
}