BZOJ1036 树的统计Count(树链剖分+线段树)
题目大意
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
解题思路
树链剖分版题
AC代码
#include<bits/stdc++.h>
using namespace std;
const int size=3e4+5;
const int inf=1e9;
int head[size],nxt[2*size];
int to[2*size];
int tot,n;
int p[size];
struct node{
int mx,sums;
node():mx(-inf),sums(0){}
node(int mx,int sums):mx(mx),sums(sums){}
friend node operator+(node a,node b)
{
return node(max(a.mx,b.mx),a.sums+b.sums);
}
}tree[size<<2];
inline int lson(int x) {return x<<1;}
inline int rson(int x) {return x<<1|1;}
void addedge(int u,int v)
{
to[tot]=v;
nxt[tot]=head[u];
head[u]=tot++;
}
int sz[size],son[size],dep[size],dfn[size];
int fa[size],top[size];
int opi[size];
int cnt;
void dfs1(int v,int f)
{
sz[v]=1;
son[v]=0;
fa[v]=f;
dep[v]=dep[f]+1;
for(int i=head[v];i!=-1;i=nxt[i])
{
if(to[i]==f) continue;
dfs1(to[i],v);
sz[v]+=sz[to[i]];
if(sz[to[i]]>sz[son[v]]) son[v]=to[i];
}
}
void dfs2(int v,int f,int k)
{
dfn[v]=++cnt;
opi[cnt]=v;
top[v]=k;
if(son[v]!=0) dfs2(son[v],v,k);
for(int i=head[v];i!=-1;i=nxt[i])
{
if(to[i]==f||to[i]==son[v]) continue;
dfs2(to[i],v,to[i]);
}
}
void build(int id,int l,int r)
{
if(l==r)
{
tree[id].mx=p[opi[l]];
tree[id].sums=p[opi[l]];
return ;
}
int mid=(l+r)>>1;
build(lson(id),l,mid);
build(rson(id),mid+1,r);
tree[id].mx=max(tree[lson(id)].mx,tree[rson(id)].mx);
tree[id].sums=tree[lson(id)].sums+tree[rson(id)].sums;
}
void update(int id,int l,int r,int u,int t)
{
if(l==r&&l==u)
{
tree[id].mx=t;
tree[id].sums=t;
return ;
}
int mid=(l+r)>>1;
if(mid>=u) update(lson(id),l,mid,u,t);
else update(rson(id),mid+1,r,u,t);
tree[id]=tree[lson(id)]+tree[rson(id)];
}
node query(int id,int l,int r,int ql,int qr)
{
if(ql==l&&qr==r)
return tree[id];
int mid=(l+r)/2;
if(mid>=qr) return query(lson(id),l,mid,ql,qr);
else if(mid<ql) return query(rson(id),mid+1,r,ql,qr);
else return query(lson(id),l,mid,ql,mid)+query(rson(id),mid+1,r,mid+1,qr);
}
node qtree(int u,int v)
{
node ans;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);
ans=ans+query(1,1,cnt,dfn[top[u]],dfn[u]);
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
ans=ans+query(1,1,cnt,dfn[v],dfn[u]);
return ans;
}
int main()
{
scanf("%d",&n);
int u,v;
memset(head,-1,sizeof(head));
cnt=0;tot=0;
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
for(int i=1;i<=n;i++) scanf("%d",&p[i]);
dep[0]=0;sz[0]=0;
dfs1(1,0);
dfs2(1,0,1);
build(1,1,cnt);
int q;
scanf("%d",&q);
char op[10];
while(q--)
{
scanf("%s",op);
scanf("%d%d",&u,&v);
if(op[1]=='M')
{
printf("%d\n",qtree(u,v).mx);
}
else if(op[1]=='S')
{
printf("%d\n",qtree(u,v).sums);
}
else
{
update(1,1,cnt,dfn[u],v);
}
}
}