http://acm.hdu.edu.cn/showproblem.php?pid=5692
题意,一颗树,从0出发,必须经过某个节点,一个节点只能经过一次,总价值最大获得多少?
思路,基础的DFS序+线段树瞎搞一下子。每个节点处理一下总价值就可以了。那么就是一个裸的区间最大值了。愉快的水一下。
代码:
#include <stdio.h>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
#include <fstream>
using namespace std;
const int MAX=100010;
int n,m,time;
int vis[MAX+10];
vector<int> son[MAX+10];
int val[MAX+10];
long long arr[MAX+10];
int line[MAX+10][2];
struct pr{
long long sum,lazy;
int left,right;
}tr[MAX*3+1000];
inline int ll(int k) {return 2*k;}
inline int rr(int k) {return 2*k+1;}
inline int mid(int kk1,int kk2) {return (kk1+kk2)>>1;}
void pushdown(int k) {
if (tr[k].lazy) {
tr[ll(k)].sum+=tr[k].lazy;tr[rr(k)].sum+=tr[k].lazy;
tr[ll(k)].lazy+=tr[k].lazy;tr[rr(k)].lazy+=tr[k].lazy;
tr[k].lazy=0;
}
}
void build(int k,int s,int t) {
tr[k].left=s;tr[k].right=t;tr[k].lazy=0;
if(s==t) {tr[k].sum=arr[s];return;}
build(ll(k),s,mid(s,t));
build(rr(k),mid(s,t)+1,t);
tr[k].sum=max(tr[ll(k)].sum,tr[rr(k)].sum);
}
void modify(int k,int s,int t,long long x) {
int l=tr[k].left,r=tr[k].right;
if(l==s&&r==t) {
tr[k].lazy+=x;
tr[k].sum+=x;
return ;
}
pushdown(k);
int mi=mid(l,r);
if(t<=mi) modify(ll(k),s,t,x);
else if(s>mi) modify(rr(k),s,t,x);
else {
modify(ll(k),s,mi,x);
modify(rr(k),mi+1,t,x);
}
tr[k].sum=max(tr[ll(k)].sum,tr[rr(k)].sum);
}
long long query(int k,int s,int t) {
int l=tr[k].left,r=tr[k].right;
if(l==s&&r==t) return tr[k].sum;
pushdown(k);
int mi=mid(l,r);
if (t<=mi) return query(ll(k),s,t);
else if(s>mi) return query(rr(k),s,t);
else return max(query(ll(k),s,mi),query(rr(k),mi+1,t));
}
void dfs(int x,long long add){
line[x][0]=++time;
arr[time]=add+val[x];
vis[x]=1;
for(int i=0;i<son[x].size();i++){
if(vis[son[x][i]]==0) dfs(son[x][i],arr[line[x][0]]);
}
line[x][1]=time;
return;
}
int main(){
int t,cas=1;
//fstream fcin,fout;
//fcin.open("in.txt");fout.open("out.txt");
//fcin>>t;
scanf("%d",&t);
while(t--){
scanf("%d%d",&n,&m);
//fcin>>n>>m;
int a,b,c;
memset(vis,0,sizeof(vis));
time=0;
for(int i=0;i<n;i++) son[i].clear();
for(int i=0;i<n-1;i++){
//fcin>>a>>b;
scanf("%d%d",&a,&b);
son[a].push_back(b);son[b].push_back(a);
}
for(int i=0;i<n;i++) scanf("%d",&val[i]);
dfs(0,0);
build(1,1,time);
//fout<<"Case #"<<cas++<<":\n";
printf("Case #%d:\n",cas++);
for(int i=1;i<=m;i++){
//fcin>>a;
scanf("%d",&a);
if(a==0){
//fcin>>b>>c;
scanf("%d%d",&b,&c);
modify(1,line[b][0],line[b][1],c-val[b]);
val[b]=c;
}
else{
//fcin>>b;
scanf("%d",&b);
//fout<<query(1,line[b][0],line[b][1])<<endl;
printf("%lld\n",query(1,line[b][0],line[b][1]));
}
}
}
return 0;
}