bzoj1036: [ZJOI2008]树的统计Count
树链剖分裸题,水到觉得发上来不太好。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#define mxn 100010
using namespace std;
int n,m,u,v,s,t=1;
int dep[mxn],f[mxn],siz[mxn],son[mxn],pos[mxn],npos[mxn],top[mxn],a[mxn];
int to[mxn<<1],hd[mxn<<1],lk[mxn],cnt=0;
char ch[5];
void add(int st,int ed)
{to[cnt]=ed,hd[cnt]=lk[st],lk[st]=cnt++;}
void dfs(int k)
{
siz[k]=1;
dep[k]=dep[f[k]]+1;
for(int i=lk[k];i>=0;i=hd[i])
if(to[i]!=f[k])
{
f[to[i]]=k;
dfs(to[i]);
siz[k]+=siz[to[i]];
if(siz[to[i]]>siz[son[k]])
son[k]=to[i];
}
}
void dfss(int k)
{
pos[t]=k,npos[k]=t++;
if(son[k])
top[son[k]]=top[k],dfss(son[k]);
for(int i=lk[k];i>=0;i=hd[i])
if(to[i]!=son[k]&&to[i]!=f[k])
top[to[i]]=to[i],dfss(to[i]);
}
int mx[mxn<<2],sum[mxn<<2],ans;
int ll[mxn<<2],rr[mxn<<2];
void update(int k)
{
sum[k]=sum[k<<1]+sum[k<<1|1];
mx[k]=mx[k<<1];
if(mx[k]<mx[k<<1|1])mx[k]=mx[k<<1|1];
}
void build(int k,int l,int r)
{
ll[k]=l,rr[k]=r;
if(l==r)sum[k]=mx[k]=a[pos[l]];
else
{
build(k<<1,l,(l+r)>>1);
build(k<<1|1,((l+r)>>1)+1,r);
update(k);
}
}
int qs(int k,int l,int r)
{
if(ll[k]==l&&rr[k]==r)
return sum[k];
else
{
int mid=(ll[k]+rr[k])>>1;
if(l>mid)return qs(k<<1|1,l,r);
else if(r<=mid)return qs(k<<1,l,r);
else return qs(k<<1,l,mid)+qs(k<<1|1,mid+1,r);
}
}
int qm(int k,int l,int r)
{
if(ll[k]==l&&rr[k]==r)
return mx[k];
else
{
int mid=(ll[k]+rr[k])>>1;
if(l>mid)return qm(k<<1|1,l,r);
else if(r<=mid)return qm(k<<1,l,r);
else return max(qm(k<<1,l,mid),qm(k<<1|1,mid+1,r));
}
}
void change(int k,int lr,int val)
{
if(ll[k]==rr[k])sum[k]=mx[k]=val;
else
{
int mid=(ll[k]+rr[k])>>1;
if(lr>mid)change(k<<1|1,lr,val);
else change(k<<1,lr,val);
update(k);
}
}
int qsum()
{
ans=0;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]])
u^=v^=u^=v;
ans+=qs(1,npos[top[u]],npos[u]);
u=f[top[u]];
}
if(dep[u]<dep[v])u^=v^=u^=v;
ans+=qs(1,npos[v],npos[u]);
return ans;
}
int qmax()
{
ans=-50000;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]])
u^=v^=u^=v;
ans=max(ans,qm(1,npos[top[u]],npos[u]));
u=f[top[u]];
}
if(dep[u]<dep[v])u^=v^=u^=v;
ans=max(ans,qm(1,npos[v],npos[u]));
return ans;
}
int main()
{
memset(lk,-1,sizeof(lk));
scanf("%d",&n);
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
dfs(1),dfss(1);
build(1,1,n);
scanf("%d",&m);
while(m--)
{
scanf("\n%s%d%d",ch,&u,&v);
if(ch[1]=='H')change(1,npos[u],v);
else if(ch[1]=='M')printf("%d\n",qmax());
else printf("%d\n",qsum());
}
}