4034: [HAOI2015]树上操作
Time Limit: 10 Sec Memory Limit: 256 MB
Submit: 4844 Solved: 1550
[Submit][Status][Discuss]
Description
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
Input
第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1
行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。再接下来 M 行,每行分别表示一次操作。其中
第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。
Output
对于每个询问操作,输出该询问的答案。答案之间用换行隔开。
Sample Input
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
Sample Output
6
9
13
HINT
对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。
链剖真的是板一样的东西啊qvq求一发dfs序将要超时的树上操作转化为区间处理,然后就可以愉快地用线段树维护了。
根据dfs序处理的不一样,这道题也有不一样的做法。
法1:
如果要查询的点u,它的top为1,则证明它到1的路径上经过点的编号是连续的一段,可以直接用区间求和。如果不是连续的一段,那么我们就可以让它沿着父亲一直跳,直到跳到top为1的时候。
而每次它到它的top都是连续的一段,且不断逼近根,所以每次就跳到当前top的爸爸,然后又求它到上一个祖先的距离。
/**************************************************************
Problem: 4034
User: LaLaLa112138
Language: C++
Result: Accepted
Time:2552 ms
Memory:34032 kb
****************************************************************/
#include<cstdio>
#include<cstring>
#include<iostream>
#define ms(x,y) memset(x,y,sizeof(x))
#define ll long long
using namespace std;
const int N = 400010;
int n,m;
int a[N],a1[N];
inline int Max(int a,int b){
return a>b?a:b;
}
struct node{
int pre,v;
}edge[N];
int num=0,head[N];
void addedge(int from,int to){
num++;
edge[num].pre=head[from];
edge[num].v=to;
head[from]=num;
}
int fa[N],dep[N],son[N],siz[N];
void dfs1(int u,int f,int d){
dep[u]=d,fa[u]=f,siz[u]=1;
for(int i=head[u];i;i=edge[i].pre){
int v=edge[i].v;
if(v==f) continue;
dfs1(v,u,d+1);
siz[u]+=siz[v];
if(son[u]==-1||siz[v]>siz[son[u]]){
son[u]=v;
}
}
}
int top[N],in[N],out[N],seq[N];
int indx=0;
void dfs2(int u,int tp){
indx++;
in[u]=out[u]=indx,seq[indx]=u;
a1[in[u]]=a[u],top[u]=tp;
if(son[u]==-1) return ;
dfs2(son[u],tp);
for(int i=head[u];i;i=edge[i].pre){
int v=edge[i].v;
if(v==son[u]||v==fa[u]) continue;
dfs2(v,v);
}
out[u]=indx;
}
struct Node{
ll sum,flag;
Node *ls,*rs;
void update(){
sum=ls->sum+rs->sum;
}
void pushdown(int l,int r){
if(flag){
int mid=(l+r)>>1;
ls->flag+=flag;
ls->sum+=(long long)(mid-l+1)*flag;
rs->flag+=flag;
rs->sum+=(long long)(r-mid)*flag;
flag=0;
}
}
}*root,pool[N],*tail=pool;
Node *build(int l,int r){
Node *bt=++tail;
if(l==r){
bt->sum=a1[l];
bt->flag=0;
}
else{
int mid=(l+r)>>1;
bt->ls=build(l,mid);
bt->rs=build(mid+1,r);
bt->update();
}
return bt;
}
void modify(Node *bt,int l,int r,int pos,int val,int delta){
if(pos<=l&&val>=r){
bt->sum+=(long long)(r-l+1)*delta;
bt->flag += delta;
return ;
}
int mid=(l+r)>>1;
bt->pushdown(l,r);
if(pos<=mid) modify(bt->ls,l,mid,pos,val,delta);
if(val>mid) modify(bt->rs,mid+1,r,pos,val,delta);
bt->update();
}
ll query1(Node *bt,int l,int r,int pos,int val){
if(pos<=l&&val>=r){
return bt->sum;
}
int mid=(l+r)>>1;
bt->pushdown(l,r);
ll ans=0;
if(pos<=mid) ans+=query1(bt->ls,l,mid,pos,val);
if(val>mid) ans+=query1(bt->rs,mid+1,r,pos,val);
bt->update();
return ans;
}
ll query2(int u,int v){
int f1=top[u],f2=top[v];
ll sum=0;
while(f1!=f2){
if(dep[f1]<dep[f2]) swap(f1,f2),swap(u,v);
sum+=query1(root,1,indx,in[f1],in[u]);
u=fa[f1];
f1=top[u];
}
if(dep[u]>dep[v]) swap(u,v);
sum+=query1(root,1,indx,in[u],in[v]);
return sum;
}
int main(){
ms(son,-1);
scanf("%d%d",&n,&m);
for(register int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(register int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);addedge(v,u);
}
dfs1(1,-1,0);dfs2(1,1);
root=build(1,indx);
// for(int i=1;i<=indx;i++)
// printf("%d\n",query1(root,1,indx,in[i],out[i]));
while(m--){
int x;
scanf("%d",&x);
if(x==1){
int u,delta;
scanf("%d%d",&u,&delta);
modify(root,1,indx,in[u],in[u],delta);
}
else if(x==2){
int u,delta;
scanf("%d%d",&u,&delta);
modify(root,1,indx,in[u],out[u],delta);
}
else{
int u;
scanf("%d",&u);
printf("%lld\n",query2(1,u));
}
}
return 0;
}
方法2:
时间戳一直加,即最后总的时间戳为n<<1
对每个节点,进的时候存正值,出的时候存负值,同时记录下这个修改这个点的值的时候应该加还是减,记得统计一段区间的真实节点和要减的节点的和。
易WA点:
不能直接用a[l]的值为正为负来判断这个节点应加还是减,如果a[l]为0就gg了。(毕竟是WA在这上的人)
/**************************************************************
Problem: 4034
User: LaLaLa112138
Language: C++
Result: Accepted
Time:2848 ms
Memory:69972 kb
****************************************************************/
#include<cstdio>
#include<cstring>
#include<iostream>
#define ms(x,y) memset(x,y,sizeof(x))
#define ll long long
using namespace std;
const int N = 400010<<1;
int n,m;
int io[N],a[N],a1[N];
inline int Max(int a,int b){
return a>b?a:b;
}
struct node{
int pre,v;
}edge[N];
int num=0,head[N];
void addedge(int from,int to){
num++;
edge[num].pre=head[from];
edge[num].v=to;
head[from]=num;
}
int fa[N],dep[N],son[N],siz[N];
void dfs1(int u,int f,int d){
dep[u]=d,fa[u]=f,siz[u]=1;
for(int i=head[u];i;i=edge[i].pre){
int v=edge[i].v;
if(v==f) continue;
dfs1(v,u,d+1);
siz[u]+=siz[v];
if(son[u]==-1||siz[v]>siz[son[u]]){
son[u]=v;
}
}
}
int top[N],in[N],out[N],seq[N];
int indx=0;
void dfs2(int u,int tp){
indx++;
in[u]=indx,seq[indx]=u,io[indx]=1;
a1[in[u]]=a[u],top[u]=tp;
if(son[u]==-1){
indx++;
out[u]=indx,io[indx]=-1;
a1[out[u]]=-a[u];
return ;
}
dfs2(son[u],tp);
for(int i=head[u];i;i=edge[i].pre){
int v=edge[i].v;
if(v==son[u]||v==fa[u]) continue;
dfs2(v,v);
}
indx++;
out[u]=indx,io[indx]=-1;
a1[out[u]]=-a[u];
}
struct Node{
int h;
ll sum,flag;
Node *ls,*rs;
void update(){
sum=ls->sum+rs->sum;
}
void pushdown(int l,int r){
if(flag){
int x=ls->h;
ls->flag+=flag;
ls->sum+=(long long)x*flag;
x=rs->h;
rs->flag+=flag;
rs->sum+=(long long)x*flag;
flag=0;
}
}
}*root,pool[N],*tail=pool;
Node *build(int l,int r){
Node *bt=++tail;
if(l==r){
bt->sum=a1[l];
if(io[l]>0) bt->h=1;
else bt->h=-1;
bt->flag=0;
}
else{
int mid=(l+r)>>1;
bt->ls=build(l,mid);
bt->rs=build(mid+1,r);
bt->update();
bt->h=bt->ls->h+bt->rs->h;
}
return bt;
}
void modify(Node *bt,int l,int r,int pos,int val,int delta){
if(pos<=l&&val>=r){
int x=bt->h;
bt->sum+=(long long)x*delta;
bt->flag += delta;
return ;
}
int mid=(l+r)>>1;
bt->pushdown(l,r);
if(pos<=mid) modify(bt->ls,l,mid,pos,val,delta);
if(val>mid) modify(bt->rs,mid+1,r,pos,val,delta);
bt->update();
}
ll query1(Node *bt,int l,int r,int pos,int val){
if(pos<=l&&val>=r){
return bt->sum;
}
int mid=(l+r)>>1;
bt->pushdown(l,r);
ll ans=0;
if(pos<=mid) ans+=query1(bt->ls,l,mid,pos,val);
if(val>mid) ans+=query1(bt->rs,mid+1,r,pos,val);
return ans;
}
int main(){
ms(son,-1);
scanf("%d%d",&n,&m);
for(register int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(register int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v);addedge(v,u);
}
dfs1(1,-1,0);dfs2(1,1);
root=build(1,indx);
//for(int i=1;i<=n;i++)
// printf("%d %d\n",in[i],out[i]);
// printf("%d\n",query1(root,1,indx,in[i],out[i]));
while(m--){
int x;
scanf("%d",&x);
if(x==1){
int u,delta;
scanf("%d%d",&u,&delta);
modify(root,1,indx,in[u],in[u],delta);
modify(root,1,indx,out[u],out[u],delta);
}
else if(x==2){
int u,delta;
scanf("%d%d",&u,&delta);
modify(root,1,indx,in[u],out[u],delta);
}
else{
int u;
scanf("%d",&u);
printf("%lld\n",query1(root,1,indx,1,in[u]));
}
}
return 0;
}