BZOJ 1036: [ZJOI2008]树的统计Count
题目
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
4
1
2
2
10
6
5
6
5
16
题目
树链剖分,(第一次写这个,代码里加了点说明)
题解
#include<cstdio>
#define INF 0x7fffffff
using namespace std;
int n,m,tot,sz,q;
int dep[30005],size[30005],lnk[30005],fa[30005],pos[30005],top[30005];
char s[10];
struct tree
{
int l,r,sum,mx;
} tr[90005];
struct edge
{
int nxt,y;
} e[60005];
int readln()
{
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9') {if (ch=='-') f=-1;ch=getchar();}
while ('0'<=ch&&ch<='9') x=x*10+ch-48,ch=getchar();
return x*f;
}
void add(int x,int y) //建边
{
tot++;e[tot].nxt=lnk[x];lnk[x]=tot;e[tot].y=y;
tot++;e[tot].nxt=lnk[y];lnk[y]=tot;e[tot].y=x;
}
int max(int x,int y){return x>y?x:y;}
void sort(int &x,int &y)
{
int t=x;x=y;y=t;
}
void dfs(int x) //以1为根遍历统计以点x为根的子树的大小、点x的深度以及点x的父节点
{
size[x]=1;
for (int i=lnk[x];i;i=e[i].nxt)
{
int y=e[i].y;
if (y==fa[x]) continue;
dep[y]=dep[x]+1;fa[y]=x;
dfs(y);
size[x]+=size[y];
}
}
void dfs1(int x,int to) //将树拆为n段重链,(可以证明有n≤log2(n)条重链),记录每条重链中,深度最小的那个节点,作为该重链的top节点
{
int k=0,y=0;
sz++;pos[x]=sz;
top[x]=to;
for (int i=lnk[x];i;i=e[i].nxt)
{
y=e[i].y;
if (dep[y]>dep[x]&&size[y]>size[k]) k=y;
}
if (k==0) return;
dfs1(k,to);
for (int i=lnk[x];i;i=e[i].nxt)
{
y=e[i].y;
if (dep[y]>dep[x]&&k!=y) dfs1(y,y);
}
}
void build(int l,int r,int rt) //线段树的操作,应该就不用讲了吧......
{
tr[rt].l=l;tr[rt].r=r;
if (l==r) return;
int mid=(l+r)>>1;
build(l,mid,rt<<1);build(mid+1,r,rt<<1|1);
}
void update(int x,int y,int rt)
{
int l=tr[rt].l,r=tr[rt].r;
if (l==r) {tr[rt].sum=tr[rt].mx=y;return;}
int mid=(l+r)>>1;
if (x<=mid) update(x,y,rt<<1); else update(x,y,rt<<1|1);
tr[rt].sum=tr[rt<<1].sum+tr[rt<<1|1].sum;
tr[rt].mx=max(tr[rt<<1].mx,tr[rt<<1|1].mx);
}
int querys(int l,int r,int rt)
{
int ll=tr[rt].l,rr=tr[rt].r;
if (ll==l&&rr==r) return tr[rt].sum;
int mid=(ll+rr)>>1;
if (r<=mid) return querys(l,r,rt<<1);
else if (mid<l) return querys(l,r,rt<<1|1);
else return querys(l,mid,rt<<1)+querys(mid+1,r,rt<<1|1);
}
int querym(int l,int r,int rt)
{
int ll=tr[rt].l,rr=tr[rt].r;
if (ll==l&&rr==r) return tr[rt].mx;
int mid=(ll+rr)>>1;
if (r<=mid) return querym(l,r,rt<<1);
else if (mid<l) return querym(l,r,rt<<1|1);
else return max(querym(l,mid,rt<<1),querym(mid+1,r,rt<<1|1));
}
int solves(int x,int y) //统计sum,因为如果x,y不在同一条重链上,那么就需要不断把top节点较低的那个重链上的sum信息加入总sum中,直到x,y在同一条重链上时,再用线段树统计一遍sum,才算是从x到y路径上所有的点都走了一遍
{
int sum=0;
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) sort(x,y);
sum+=querys(pos[top[x]],pos[x],1);
x=fa[top[x]];
}
if (pos[x]>pos[y]) sort(x,y);
sum+=querys(pos[x],pos[y],1);
return sum;
}
int solvem(int x,int y) //和上面那个操作差不多,只不过改成取最大值而已
{
int mx=-INF;
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) sort(x,y);
mx=max(mx,querym(pos[top[x]],pos[x],1));
x=fa[top[x]];
}
if (pos[x]>pos[y]) sort(x,y);
mx=max(mx,querym(pos[x],pos[y],1));
return mx;
}
int main()
{
n=readln();
for (int i=1;i<n;i++) add(readln(),readln());
dfs(1);dfs1(1,1);
build(1,n,1);
for (int i=1;i<=n;i++) update(pos[i],readln(),1);
q=readln();
while (q--)
{
scanf("%s",s);
int x=readln(),y=readln();
if (s[0]=='C') update(pos[x],y,1);
else if (s[1]=='M') printf("%d\n",solvem(x,y));
else printf("%d\n",solves(x,y));
}
}