Snacks
题目链接
分析:首先,先求出每个节点的第一次出现时的dfn序,用in数组维护,第二次出现的dfn序,用out数组来维护,然后,区间【in[x],out[x]】就是节点的儿子及其自己,所以可以用线段树来维护这段区间的最大值,及题目所求。
代码:
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <queue>
#include <set>
#include <map>
#include <algorithm>
#include <math.h>
#include <vector>
using namespace std;
typedef long long ll;
const int mod=1e9+7;
const int maxn=2e5+10;
const ll inf=1e17;
//first表示节点第一次出现的位置
//in表示第一次出现的dfn序,out表示第二次出现时的dfn序
int in[maxn],out[maxn],idx,first[maxn];
ll dis[maxn];
ll val[maxn];
int n,m;
//建图
struct Edge{
int to,nex;
}edge[maxn];
int tot,head[maxn];
void init(){
tot=0;
idx=0;
memset (head,-1,sizeof (head));
}
void addedge(int u,int v){
edge[tot]=Edge{v,head[u]};head[u]=tot++;
}
//求dfn序,in,out,first
void dfs(int u,int fa){
in[u]=++idx;
first[idx]=u;
for (int i=head[u];i!=-1;i=edge[i].nex){
int v=edge[i].to;
if (v==fa)continue;
dis[v]=dis[u]+val[v];
dfs(v,u);
}
out[u]=idx;
}
ll sum[maxn<<2],setv[maxn<<2];
void push_up(int rt){
sum[rt]=max(sum[rt*2],sum[rt*2+1]);
}
void push_down(int rt){
if (setv[rt]){
setv[rt*2]+=setv[rt];
setv[rt*2+1]+=setv[rt];
sum[rt*2]+=setv[rt];
sum[rt*2+1]+=setv[rt];
setv[rt]=0;
}
}
void build (int rt,int l,int r){//建树
setv[rt]=0;
if (l==r){
sum[rt]=dis[first[l]];//第一次出现是时的节点的距离
return;
}
int mid=(l+r)/2;
build(rt*2,l,mid);
build(rt*2+1,mid+1,r);
push_up(rt);
}
void update(int rt,int l,int r,int ul,int ur,int val){//更新
if(ul<=l&&ur>=r){
sum[rt]+=val;
setv[rt]+=val;
return;
}
push_down(rt);
int mid=(r+l)/2;
if (ul<=mid)update(rt*2,l,mid,ul,ur,val);
if (ur>mid)update(rt*2+1,mid+1,r,ul,ur,val);
push_up(rt);
}
ll query(int rt,int l,int r,int ql,int qr){//查询
if (ql<=l&&qr>=r){
return sum[rt];
}
push_down(rt);
int mid=(l+r)/2;
ll ans=-inf;
if (ql<=mid)ans=max(ans,query(rt*2,l,mid,ql,qr));
if (qr>mid)ans=max(ans,query(rt*2+1,mid+1,r,ql,qr));
push_up(rt);
return ans;
}
int main()
{
int T;
scanf ("%d",&T);
int cas=1;
while (T--){
init();
scanf ("%d%d",&n,&m);
int u,v;
for (int i=0;i<n-1;i++){
scanf ("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
for (int i=0;i<n;i++){
scanf ("%lld",&val[i]);
}
dis[0]=val[0];
dfs(0,-1);
build(1,1,n);
// for (int i=0;i<n;i++){
// printf ("dis %d = %d\n",i,dis[i]);
// }
int op;
printf ("Case #%d:\n",cas++);
while (m--){
scanf ("%d",&op);
if (op==1){
scanf ("%d",&u);
ll ans=query(1,1,n,in[u],out[u]);
printf ("%lld\n",ans);
}else {
scanf("%d%d",&u,&v);
update(1,1,n,in[u],out[u],v-val[u]);
val[u]=v;
}
}
}
return 0;
}