思路:
树链剖分模板,推荐博客
(https://www.cnblogs.com/chinhhh/p/7965433.html)
Code:
#include<climits>
#include<cmath>
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<string>
using namespace std;
// n: number of tree nodes (1-based); m: number of queries read in main().
// NOTE(review): the globals r and p are never read in this file (the parameters named r
// in the segment-tree functions shadow the global) - likely template leftovers, confirm before removing.
long long n, m, r, p;
// w[x]: weight of node x as read from input; head[x]/tot: adjacency-list heads and edge counter for b[].
// Arrays are sized 100100 and indexed from 1, so n must stay below 100100.
long long w[100100], head[100100], tot;
// Filled by dfs(): depth, parent, heavy child (0 = none), subtree size. top[x] (chain head) is filled by dfs1().
long long dep[100100], fa[100100], son[100100], siz[100100], top[100100];
// Filled by dfs1(): we[pos] = weight of the node placed at DFS position pos, id[x] = position of node x,
// cnt = last position handed out.
long long we[100100], id[100100], cnt;
// Segment-tree node over positions of we[]: covered interval [l, r] (only written by build),
// sum and maximum of the values in that interval. 4x the position count is the usual safe size.
struct abc
{
long long sum, l, r, maxx;
}a[100100<<2];
// Adjacency-list edge: target node and index of the next edge with the same source (0 ends the list).
// Field order matters: add() builds edges with aggregate initialization {to, next}.
// Sized 200100 because every undirected tree edge is stored twice.
struct node
{
long long to, next;
}b[200100];
// Append the directed edge x -> y to node x's adjacency list.
// New edges are linked in front, so iteration visits them in reverse insertion order.
void add(long long x, long long y)
{
    ++tot;
    b[tot].to = y;
    b[tot].next = head[x];
    head[x] = tot;
}
// Build segment-tree node k over DFS positions [l, r] of we[].
// Each node stores its interval plus the sum and maximum of we[l..r];
// children of k are 2k (left half) and 2k+1 (right half).
void build(long long k, long long l, long long r)
{
    a[k].l = l;
    a[k].r = r;
    if (l != r)
    {
        const long long half = (l + r) >> 1;
        const long long lc = k << 1;
        const long long rc = lc | 1;
        build(lc, l, half);
        build(rc, half + 1, r);
        // Combine children once both halves are built.
        a[k].sum = a[lc].sum + a[rc].sum;
        a[k].maxx = max(a[lc].maxx, a[rc].maxx);
        return;
    }
    // Leaf: a single position.
    a[k].sum = we[l];
    a[k].maxx = we[l];
}
// Assign value z to every position in [x, y] that lies under segment-tree node k,
// where node k covers positions [l, r]. sum/maxx are recomputed on the way back up.
// There are no lazy tags, so the assignment descends to the leaves. A covered internal
// node cannot simply store z as its sum, because its segment may hold several positions.
// For a single position (x == y, as issued by the CHANGE command) this is the usual
// O(log n) point update. Ranges cost O(len * log n).
void change(long long k, long long l, long long r, long long x, long long y, long long z)
{
    // [x, y] does not touch this node's interval (also handles an empty x > y range).
    if (y < l || r < x)
        return;
    if (l == r)
    {
        // A leaf that passed the overlap check lies inside [x, y].
        a[k].sum = z;
        a[k].maxx = z;
        return;
    }
    long long mid = (l + r) >> 1;
    if (x <= mid)
        change(k * 2, l, mid, x, y, z);
    if (y > mid)
        change(k * 2 + 1, mid + 1, r, x, y, z);
    a[k].sum = a[k * 2].sum + a[k * 2 + 1].sum;
    a[k].maxx = max(a[k * 2].maxx, a[k * 2 + 1].maxx);
}
// Return the sum of the values at positions [x, y] under segment-tree node k (covering [l, r]).
// Callers pass x <= y with [x, y] overlapping [l, r].
// This is a pure read: the original re-wrote a[k].sum here, which was redundant because
// build()/change() already keep every node consistent with its children.
long long query_sum(long long k, long long l, long long r, long long x, long long y)
{
    if (x <= l && r <= y)
        return a[k].sum;
    long long mid = (l + r) >> 1, tmp = 0;
    if (x <= mid)
        tmp += query_sum(k * 2, l, mid, x, y);
    if (y > mid)
        tmp += query_sum(k * 2 + 1, mid + 1, r, x, y);
    return tmp;
}
// Return the maximum value at positions [x, y] under segment-tree node k (covering [l, r]).
// Callers pass x <= y with [x, y] overlapping [l, r].
// The accumulator starts at LLONG_MIN, the identity for max over long long.
// The former -1e9 sentinel made any true maximum below -1e9 come back as -1e9.
long long query_maxx(long long k, long long l, long long r, long long x, long long y)
{
    if (x <= l && r <= y)
        return a[k].maxx;
    long long mid = (l + r) >> 1, tmp = LLONG_MIN;
    if (x <= mid)
        tmp = max(tmp, query_maxx(k * 2, l, mid, x, y));
    if (y > mid)
        tmp = max(tmp, query_maxx(k * 2 + 1, mid + 1, r, x, y));
    return tmp;
}
// Assign value k to every node on the tree path between x and y.
// Repeatedly lifts the endpoint whose chain head is deeper, updating that whole chain
// segment, until both endpoints are on one heavy chain. Then it updates the remaining segment.
// main() only calls this with x == y, i.e. a single-node update.
void change_tree(long long x, long long y, long long k)
{
    long long u = x;
    long long v = y;
    while (top[u] != top[v])
    {
        if (dep[top[v]] > dep[top[u]])
            swap(u, v);
        // Positions of a heavy chain are contiguous: [id[head], id[u]].
        change(1, 1, n, id[top[u]], id[u], k);
        u = fa[top[u]];
    }
    // Same chain: the shallower endpoint has the smaller DFS position.
    const long long shallow = (dep[u] > dep[v]) ? v : u;
    const long long deep = (dep[u] > dep[v]) ? u : v;
    change(1, 1, n, id[shallow], id[deep], k);
}
// Query the tree path between nodes x and y.
// flag == 1 returns the sum of the node values on the path; any other flag returns their maximum.
// Only the requested aggregate is queried: the original ran both segment-tree queries on
// every chain and discarded one result. The max accumulator starts at LLONG_MIN, so weights
// below -1e9 are handled correctly (the old -1e9 sentinel clamped them).
long long query_tree(long long x, long long y, long long flag)
{
    const bool want_sum = (flag == 1);
    long long ans = 0, ansm = LLONG_MIN;
    while (top[x] != top[y])
    {
        // Lift the endpoint whose chain head is deeper; its chain segment is [id[top[x]], id[x]].
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);
        if (want_sum)
            ans += query_sum(1, 1, n, id[top[x]], id[x]);
        else
            ansm = max(ansm, query_maxx(1, 1, n, id[top[x]], id[x]));
        x = fa[top[x]];
    }
    // Both endpoints now share a chain; make x the shallower one (smaller DFS position).
    if (dep[x] > dep[y])
        swap(x, y);
    if (want_sum)
        return ans + query_sum(1, 1, n, id[x], id[y]);
    return max(ansm, query_maxx(1, 1, n, id[x], id[y]));
}
// First HLD pass: for the subtree rooted at x (parent f, depth deep), record parent, depth and
// subtree size. Also pick the heavy child son[x]: the child with the largest subtree, where the
// first such child in adjacency order wins ties. Leaves keep son[x] == 0.
void dfs(long long x, long long f, long long deep)
{
    fa[x] = f;
    dep[x] = deep;
    siz[x] = 1;
    long long best = -1;
    for (long long e = head[x]; e != 0; e = b[e].next)
    {
        const long long child = b[e].to;
        if (child != f)
        {
            dfs(child, x, deep + 1);
            siz[x] += siz[child];
            if (best < siz[child])
            {
                best = siz[child];
                son[x] = child;
            }
        }
    }
}
// Second HLD pass: assign DFS positions so every heavy chain occupies a contiguous range.
// It copies node weights into we[] in position order and records each node's chain head
// (topf). The heavy child is visited first so it extends the current chain; every light
// child starts a new chain headed by itself.
void dfs1(long long x, long long topf)
{
    cnt++;
    id[x] = cnt;
    we[cnt] = w[x];
    top[x] = topf;
    if (son[x] == 0)
        return;
    dfs1(son[x], topf);
    for (long long e = head[x]; e; e = b[e].next)
    {
        const long long v = b[e].to;
        if (v != fa[x] && v != son[x])
            dfs1(v, v);
    }
}
// Read the tree (n, then n-1 edges, then n node weights). Run both HLD passes rooted at
// node 1 and build the segment tree. Then process m commands:
//   CHANGE u t  - set the weight of node u to t
//   QMAX u v    - print the maximum weight on the path u..v
//   QSUM u v    - print the sum of weights on the path u..v
// Unrecognized commands are skipped without reading their arguments.
int main()
{
    scanf("%lld", &n);
    for (long long e = 1; e < n; ++e)
    {
        long long u, v;
        scanf("%lld%lld", &u, &v);
        add(u, v);
        add(v, u);
    }
    for (long long i = 1; i <= n; ++i)
        scanf("%lld", &w[i]);
    dfs(1, 0, 1);
    dfs1(1, 1);
    build(1, 1, n);
    scanf("%lld", &m);
    while (m--)
    {
        string op;
        cin >> op;
        long long x, y;
        if (op == "CHANGE")
        {
            long long z;
            scanf("%lld%lld", &x, &z);
            change_tree(x, x, z);
        }
        else if (op == "QMAX" || op == "QSUM")
        {
            scanf("%lld%lld", &x, &y);
            // query_tree: flag 1 = sum, 0 = max.
            printf("%lld\n", query_tree(x, y, op == "QSUM" ? 1 : 0));
        }
    }
    return 0;
}