时间限制:3秒 内存限制:128M
【问题描述】
给定一棵n个节点的带权树,节点编号为1到n,以root为根,设sum[p]表示以点p为根的这棵子树中所有节点的权值和。支持下列两种操作:
1 给定两个整数u,v,修改点u的权值为v。
2 给定两个整数l,r,计算sum[l]+sum[l+1]+….+sum[r-1]+sum[r]
请设计一种算法,尽快实现上面的操作。
【输入格式】
第一行两个整数n,m,表示树的节点数与操作次数。
接下来一行n个整数,第i个整数di表示点i的初始权值。
接下来n行每行两个整数ai,bi,表示一条树上的边,若ai=0则说明bi是根。
接下来m行每行三个整数,第一个整数op表示操作类型。
若op=1则接下来两个整数u,v表示将点u的权值修改为v。
若op=2则接下来两个整数l,r表示询问。
【输出格式】
对每个操作类型2输出一行一个整数表示答案。
【输入样例】
6 4
0 0 3 4 0 1
0 1
1 2
2 3
2 4
3 5
5 6
2 1 2
1 1 1
2 3 6
2 3 5
【输出样例】
16
10
9
【样例解释】
【数据范围】
【来源】
BZOJ4765
这道题十分坑,一开始打了个半暴力,询问用线段树,修改直接暴力爬树,结果还有50分,后来用分块表+线段树还是只有70,最后线段树换成树状数组终于满了。(卡常数我也是醉了)
做题时我们先统计每个点在每个块中出现了几次,这个可以在dfs时统计每个块里有多少它的祖先(详见代码注释),然后循环加一下就好了。
在统计次数的同时我们生成树状数组并求每个块的总值,并生成一个按dfs序的序列,来建立分块表(这样每棵子树的点就是连续的一个线段)。
对于每次修改我们在每个块里增加它变化的值(减小就加负数)乘以出现的次数,然后改变它在树状数组上的值已经本身的值。
对于每次询问,块内的直接加就好,单独的,每次求这个点代表的线段的总值就好了。(这也是树状数组存在的意义)
详细代码如下:
#include<cstdlib>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
using namespace std;
typedef unsigned long long ll;
const int maxn=100005;
struct edge
{
int u,v,next;
}e[maxn*2];
int n,m;
int size,num,ss[405][maxn],t[maxn],b[maxn]={0};
int root,cnt=-1,Ll[405],rr[405];
ll bit[maxn*2],w[maxn*2],ww[405],sum[maxn]={0},a[maxn];
int lt[maxn],rt[maxn],cur[maxn],f[maxn]={0},cnt2=0,time=0;
int lowbit(int x){return x&(-x);}
void in(int i,ll x)
{
for(int k=i;k<=n;k+=lowbit(k)) bit[k]+=x;
}
ll find(int j){
ll sum=0;
for(int k=j;k>0;k-=lowbit(k)) sum+=bit[k];
return sum;
}
ll read(){
ll x=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9')
{
x=x*10+ch-'0';
ch=getchar();
}
return x;
}
void add(int u,int v){
e[++cnt2]=(edge){u,v,f[u]};f[u]=cnt2;
e[++cnt2]=(edge){v,u,f[v]};f[v]=cnt2;
}
void dfs(int i,int fa){
cur[++time]=i;lt[i]=time;
b[t[i]]++;sum[i]=a[i];//标记t[i]这个块中有一个点出现
for(int k=1;k<=num;k++)
ss[k][i]=b[k],ww[k]+=b[k]*a[i];//所有还没有删除的出现了的点一定是他的祖先
in(lt[i],a[i]);
for(int k=f[i];k;k=e[k].next)
{
int j=e[k].v;
if(j==fa) continue;
dfs(j,i);
sum[i]+=sum[j];
}
rt[i]=time;
b[t[i]]--;//遍历完了这个点的所以子孙,删除这个点
}
void init()
{
n=read();m=read();
int x,y,root2;
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<=n;i++)
{
x=read();y=read();
if(x==0) root2=y;
else add(x,y);
}
size=sqrt(n);
num=(n-1)/size+1;
for(int i=1;i<=n;i++)
{
t[i]=(i-1)/size+1;
if(t[i-1]!=t[i]) rr[t[i-1]]=i-1,Ll[t[i]]=i;
}
rr[t[n]]=n;
dfs(root2,-1);
}
char o[55];
void out(ll x)
{
int tt=0;
if(!x) putchar('0');
while(x) o[++tt]=x%10,x/=10;
while(tt)
{
putchar(o[tt]+'0');
tt--;
}
putchar('\n');
}
int main()
{
//freopen("in.txt","r",stdin);
//freopen("out.txt","w",stdout);
init();
int id,x,y;
ll u;
int ttt=0;
while(m--)
{
id=read();
if(id==1)
{
x=read();y=read();
u=y-a[x];a[x]=y;
for(int k=1;k<=num;k++)
ww[k]+=ss[k][x]*u;
in(lt[x],u);
}
if(id==2)
{
x=read();y=read();
ll ans=0;
for(int k=1;k<=num;k++) if(Ll[k]>=x&&rr[k]<=y)
ans+=ww[k];
int tt;
if(x>Ll[t[x]])
{
tt=t[x]*size;
while(x<=tt&&x<=y)
{
ans+=find(rt[x])-find(lt[x]-1);
x++;
}
}
if(y<rr[t[y]])
{
tt=(t[y]-1)*size+1;
while(y>=tt&&y>=x)
{
ans+=find(rt[y])-find(lt[y]-1);
y--;
}
}
out(ans);
}
}
return 0;
}