一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
- I. CHANGE u t : 把结点u的权值改为t
- II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
- III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
输入文件的第一行为一个整数n,表示节点的个数。
接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。
接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。
接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
4
1
2
2
10
6
5
6
5
16
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
题解:这是一道比较裸的树链剖分,鉴于本人脑残,所以调了一个晚上,才AC。
什么是树链剖分呢?
树链就是树上的路径,剖分就是把树分为轻重链,然后用数据结构对链进行维护。
重链顾名思义就是所用重边练成的链(一个节点的所有儿子中子树节点数最多的儿子是该节点的重儿子,重儿子与该节点的连边即为重边),剩下的都是轻链。
剖分后的树有如下性质:
性质1:如果(v,u)为轻边,则size[u] * 2 < size[v];
性质2:从根到某一点的路径上轻链、重链的个数都不大于logn。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define inf 0x7fffffff
#define N 30005
#define M 60005
using namespace std;
int n,m,next[M],point[N],v[M],val[N],tot=0,mi[20],tm[1000005],ts[1000005];
int deep[N],fa[N][20],size[N],sz=0,pos[N],belong[N],tv[N];
void add(int x,int y)//存储边的信息
{
tot++; next[tot]=point[x]; point[x]=tot; v[tot]=y;
tot++; next[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void dfs1(int x,int f,int dep)
{
size[x]=1; deep[x]=dep; //deep表示节点的深度
for (int i=1;i<=14;i++)
{
if (deep[x]-mi[i]<0) break;
fa[x][i]=fa[fa[x][i-1]][i-1];//倍增处理祖先信息
}
for (int i=point[x];i!=0;i=next[i])
if (v[i]!=f)
{
fa[v[i]][0]=x;
dfs1(v[i],x,dep+1);
size[x]+=size[v[i]];//记录以该节点为根的子树的大小
}
}
void dfs2(int x,int chain)
{
int k=0; pos[x]=++sz; //x节点在线段树中的新编号
tv[sz]=val[x];//因为重新编号,所有需要对于存储点权,为之后建立线段树做准备
belong[x]=chain;//这样每条链的起点即为chain
for (int i=point[x];i!=0;i=next[i])
if (size[v[i]]>size[k]&&deep[v[i]]>deep[x])
k=v[i];
if (k==0) return;
dfs2(k,chain);
for (int i=point[x];i!=0;i=next[i])
if (v[i]!=k&&deep[v[i]]>deep[x]) dfs2(v[i],v[i]);
}
int lca(int x,int y)//求最近公共祖先
{
if (deep[x]<deep[y]) swap(x,y);
int ch=deep[x]-deep[y];
for (int i=0;i<=14;i++)
if (ch>>i&1) x=fa[x][i];
if (x==y) return x;
for (int i=14;i>=0;i--)
if (fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void build(int now,int l,int r)//建树
{
if (l==r)
{
tm[now]=ts[now]=tv[l];
return;
}
int mid=(l+r)/2;
build(now<<1,l,mid);
build((now<<1)+1,mid+1,r);
ts[now]=ts[now<<1]+ts[(now<<1)+1];
tm[now]=max(tm[now<<1],tm[(now<<1)+1]);
}
void change(int now,int l,int r,int point,int value)//点修改
{
if (l==r)
{
ts[now]=tm[now]=value;
return;
}
int mid=(l+r)/2;
if (point<=mid)
change(now<<1,l,mid,point,value);
else
change((now<<1)+1,mid+1,r,point,value);
ts[now]=ts[now<<1]+ts[(now<<1)+1];
tm[now]=max(tm[now<<1],tm[(now<<1)+1]);
}
int qjsum(int now,int l,int r,int ll,int rr)//区间求和
{
int sum=0;
if (ll<=l&&rr>=r)
return ts[now];
int mid=(l+r)/2;
if (ll<=mid)
sum+=qjsum(now<<1,l,mid,ll,rr);
if (rr>mid)
sum+=qjsum((now<<1)+1,mid+1,r,ll,rr);
return sum;
}
int qjmax(int now,int l,int r,int ll,int rr)//区间求最值
{
int maxn=-inf;
if(ll<=l&&rr>=r)
return tm[now];
int mid=(l+r)/2;
if (ll<=mid) maxn=max(maxn,qjmax(now<<1,l,mid,ll,rr));
if (rr>mid) maxn=max(maxn,qjmax((now<<1)+1,mid+1,r,ll,rr));
return maxn;
}
int solvesum(int x,int f)
{
int sum=0;
while (belong[x]!=belong[f])//不在一条重链上就将x跳到链首,走一条轻边,如此反复
{
sum+=qjsum(1,1,n,pos[belong[x]],pos[x]);
x=fa[belong[x]][0];
}
sum+=qjsum(1,1,n,pos[f],pos[x]); return sum;
}
int solvemax(int x,int f)
{
int ans=-inf;
while (belong[x]!=belong[f])
{
ans=max(ans,qjmax(1,1,n,pos[belong[x]],pos[x]));
x=fa[belong[x]][0];
}
ans=max(ans,qjmax(1,1,n,pos[f],pos[x])); return ans;
}
int main()
{
scanf("%d",&n);
for (int i=1;i<n;i++)
{
int x,y; scanf("%d%d",&x,&y);
add(x,y);
}
for (int i=1;i<=n;i++) scanf("%d",&val[i]);
mi[0]=1;
for (int i=1;i<=14;i++) mi[i]=mi[i-1]*2;
dfs1(1,0,1);
dfs2(1,1);
scanf("%d",&m); char s[10]; int x,y;
build(1,1,n);
for (int i=1;i<=m;i++)
{
scanf("%s%d%d",s,&x,&y);
if (s[0]=='C')
{
change(1,1,n,pos[x],y),val[x]=y;
continue;
}
int t=lca(x,y); //这里就是一种RE的地方,因为刚开始CHANGE的时候也求了lca,所以。。。。。
if (s[1]=='M') printf("%d\n",max(solvemax(x,t),solvemax(y,t)));
if (s[1]=='S') printf("%d\n",solvesum(x,t)+solvesum(y,t)-val[t]);
}
}