Description
一棵树,有n个点,每个点都有一个权值,有两种操作
0 a b ,问从节点a到节点b路径上所有点权值和
1 a b,把节点a权值改为b
Input
第一行一个整数T表示用例组数,每组用例第一行为一个整数n表示树节点个数,第二行n个整数表示n个节点的权值1,之后n-1行每行两个整数a和b表示a和b有一条无向边,然后是一个整数m表示操作数,最后m行每行三个整数op u v,op=0表示查询节点a到节点b路径上所有点权值和,op=1表示把节点a权值改为b
Output
对于每次查询,输出查询结果
Sample Input
1
4
10 20 30 40
0 1
1 2
1 3
3
0 2 3
1 1 100
0 2 3
Sample Output
Case 1:
90
170
Solution
树上路径问题,首先树链剖分,然后以每个点的dfs序建线段树,线段树中元素存储区间和,那么问题转化为线段树单点更新和区间查询问题
Code
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
#define maxn 55555
struct Edge
{
int to,next;
}E[2*maxn];
struct Tree
{
int left,right,data;
}T[4*maxn];
int t,n,m,val[maxn],head[maxn],cnt,idx,size[maxn],fa[maxn],son[maxn],dep[maxn],top[maxn],id[maxn],pos[maxn];
void init()
{
cnt=idx=0;
memset(head,-1,sizeof(head));
dep[1]=fa[1]=size[0]=0;
memset(son,0,sizeof(son));
}
void add(int u,int v)
{
E[cnt].to=v;
E[cnt].next=head[u];
head[u]=cnt++;
}
void dfs1(int u)
{
size[u]=1;
for(int i=head[u];~i;i=E[i].next)
{
int v=E[i].to;
if(v!=fa[u])
{
fa[v]=u;
dep[v]=dep[u]+1;
dfs1(v);
size[u]+=size[v];
if(size[son[u]]<size[v]) son[u]=v;
}
}
}
void dfs2(int u,int topu)
{
top[u]=topu;
id[u]=++idx;
pos[idx]=u;
if(son[u]) dfs2(son[u],top[u]);
for(int i=head[u];~i;i=E[i].next)
{
int v=E[i].to;
if(v!=fa[u]&&v!=son[u]) dfs2(v,v);
}
}
void push_up(int t)
{
T[t].data=T[2*t].data+T[2*t+1].data;
}
void build(int l,int r,int t)
{
T[t].left=l;
T[t].right=r;
T[t].data=0;
if(l==r)
{
T[t].data=val[pos[l]];
return ;
}
int mid=(l+r)>>1;
build(l,mid,2*t);
build(mid+1,r,2*t+1);
push_up(t);
}
void update(int x,int v,int t)
{
if(T[t].left==x&&T[t].right==x)
{
T[t].data=v;
return ;
}
int mid=(T[t].left+T[t].right)>>1;
if(x<=mid)update(x,v,2*t);
else update(x,v,2*t+1);
push_up(t);
}
int query(int l,int r,int t)
{
if(T[t].left==l&&T[t].right==r) return T[t].data;
int ans=0;
if(r<=T[2*t].right) return query(l,r,2*t);
else if(l>=T[2*t+1].left) return query(l,r,2*t+1);
return query(l,T[2*t].right,2*t)+query(T[2*t+1].left,r,2*t+1);
}
int getsum(int u,int v)
{
int top1=top[u],top2=top[v],ans=0;
while(top1!=top2)
{
if(dep[top1]<dep[top2])
{
swap(top1,top2);
swap(u,v);
}
ans+=query(id[top1],id[u],1);
u=fa[top1];
top1=top[u];
}
if(dep[u]>dep[v]) swap(u,v);
ans+=query(id[u],id[v],1);
return ans;
}
int main()
{
scanf("%d",&t);
int res=1;
while(t--)
{
printf("Case %d:\n",res++);
init();
int u,v,w,op;
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&val[i]);
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
u++,v++;
add(u,v),add(v,u);
}
dfs1(1);
dfs2(1,1);
build(1,n,1);
scanf("%d",&m);
while(m--)
{
scanf("%d",&op);
if(op)
{
scanf("%d%d",&u,&w);
update(id[++u],w,1);
}
else
{
scanf("%d%d",&u,&v);
printf("%d\n",getsum(++u,++v));
}
}
}
return 0;
}