题目传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=4765
这道题已经攒了半年多了。。。因为懒,一直没去写。。。所以今天才把这道题写出来。。。
如果是要维护区间权值和、子树权值和,都可以用线段树/树状数组轻松解决。但是这道题要维护的是子树权值和的区间和,这就比较难搞了。
当需要维护一些看起来很难直接维护的信息时,我们一般会想到分块。于是考虑这样的分块:按编号把每√n个节点划分为一块,维护每一块所有节点的sum值的和,然后再维护每个节点的sum值。单节点的sum可以用树状数组/线段树维护,但为了降低时间复杂度,我们可以用分块维护dfs序的区间和的前缀和,这样的单节点修改复杂度为O(√n),单节点查询复杂度为O(1)。
时间复杂度:修改操作O(√n),查询操作O(√n),总时间复杂度O((n+m)√n)。
具体实现细节:维护第一层分块(即sum值的和)时可以在dfs遍历树时一个数组记录每个节点修改时对每个块的贡献,然后修改时直接统计贡献修改块的值就行了;第二层分块(即单节点的sum)时可以分别维护块的前缀和与每个节点在所在块内的前缀和,查询时把两部分加起来就行了。
另外,答案要开unsigned long long!
代码:
#include<cstdio> #include<cstring> #include<cmath> #include<cstdlib> #include<ctime> #include<algorithm> #include<queue> #include<vector> #define ll long long #define ull unsigned long long #define max(a,b) (a>b?a:b) #define min(a,b) (a<b?a:b) #define inf 0x3f3f3f3f #define mod 1000000007 #define eps 1e-18 inline ll read() { ll tmp=0; char c=getchar(),f=1; for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1; for(;'0'<=c&&c<='9';c=getchar())tmp=(tmp<<3)+(tmp<<1)+c-'0'; return tmp*f; } using namespace std; struct edge{ int to,nxt; }e[200010]; int fir[100010],l[100010],r[100010],pos[100010]; ull sum1[350],sum2[100010]; int a[100010],tmp[100010]; ull sum[350]; int w[100010][350]; int n,m,size,tot=0,root; void addedge(int x,int y){e[tot].to=y; e[tot].nxt=fir[x]; fir[x]=tot++;} void dfs(int now,int fa) { if(now!=root){ for(int i=0;i*size<n;i++)w[now][i]=w[fa][i]; } ++w[now][now/size]; l[now]=tot; pos[now]=tot++; for(int i=fir[now];~i;i=e[i].nxt) if(e[i].to!=fa)dfs(e[i].to,now); r[now]=tot-1; } void add(int x,int k) { int i,id=pos[x]/size; a[x]+=k; for(i=id;i*size<n;i++)sum1[i]+=k; for(i=pos[x];i<(id+1)*size&&i<n;i++)sum2[i]+=k; for(i=0;i*size<n;i++)sum[i]+=1ll*w[x][i]*k; } ull getsum(int x) { if(x<0)return 0; else return sum2[x]+(x<size?0:sum1[x/size-1]); } ull query(int L,int R) { int i,idL=L/size,idR=R/size; ull ans=0; if(idL==idR){ for(i=L;i<=R;i++)ans+=getsum(r[i])-getsum(l[i]-1); } else{ for(i=idL+1;i<idR;i++)ans+=sum[i]; for(i=L;i<(idL+1)*size&&i<n;i++)ans+=getsum(r[i])-getsum(l[i]-1); for(i=idR*size;i<=R;i++)ans+=getsum(r[i])-getsum(l[i]-1); } return ans; } int main() { int i; n=read(); m=read(); size=(int)sqrt(n); for(i=0;i<n;i++)tmp[i]=read(); for(i=0;i<n;i++)fir[i]=-1; for(i=1;i<=n;i++){ int x=read(),y=read(); if(!x)root=y-1; else addedge(x-1,y-1),addedge(y-1,x-1); } tot=0; dfs(root,-1); for(i=0;i<n;i++)add(i,tmp[i]); for(i=1;i<=m;i++){ int op=read(),x=read(),y=read(); if(op==1)add(x-1,y-a[x-1]); else printf("%llu\n",query(x-1,y-1)); } return 0; }