Problem J. Ka Chang
Output file: standard output
Time limit: 1 seconds
Memory limit: 128 mebibytes
Problem Description
Given a rooted tree of $n$ nodes (the root is node 1). Initially, every node has zero points.
Then, you need to handle Q operations. There’re two types:
1 L X: increase by X the points of every node whose depth equals L (the depth of the root is zero). ($X \le 10^8$)
2 X: Output sum of all points in the subtree whose root is X.
Input
Just one case.
The first line contains two integers N and Q ($N \le 10^5$, $Q \le 10^5$).
The next N−1 lines each contain two integers a and b, meaning that node a is the father of node b. It is guaranteed that the input forms a rooted tree whose root is node 1.
The next Q lines are queries.
Output
For each query of type 2, output its answer on a separate line.
Sample Input:

    3 3
    1 2
    2 3
    1 1 1
    2 1
    2 3

Sample Output:

    1
    0
先 dfs 求出每个节点的 dfs 序。
一个显而易见的事实是：如果某层的节点个数超过 $\sqrt{n}$，那么这样的层数不超过 $\sqrt{n}$ 个（因为所有层的节点总数为 $n$）。
那么，更新时，对于节点数少于 $\sqrt{n}$ 的层，暴力地在树状数组上逐点更新；对于节点较多的层，直接利用一个数组记录这一层增加的值。
查询时，先用树状数组求出子树对应的 dfs 序区间 $[in_x, out_x]$ 的和；之后，对于每个节点数不少于 $\sqrt{n}$ 的层，利用预先记录的这层所有节点的 dfs 序编号（递增），二分出这层有多少节点落在 $[in_x, out_x]$ 内，再乘以这一层累计增加的值。答案就是两部分加起来。
#include <cstdio>
#include <iostream>
#include <string.h>
#include <string>
#include <map>
#include <queue>
#include <deque>
#include <vector>
#include <set>
#include <algorithm>
#include <math.h>
#include <cmath>
#include <stack>
#include <iomanip>
#include <assert.h>
// Short-hand helpers.
#define pb push_back
#define mem0(a) memset(a,0,sizeof(a))      // zero every byte of a fixed-size array
#define meminf(a) memset(a,0x3f,sizeof(a)) // set every byte to 0x3f (an int becomes 0x3f3f3f3f)
using namespace std;

// Type aliases.
typedef long long ll;
typedef long double ld;
typedef double db;
typedef pair<int,int> pp;

// Constants.
const int maxn = 100005;                  // array capacity: N <= 1e5 plus a little slack
const int inf = 0x3f3f3f3f;               // "infinity" for int, matches the meminf byte pattern
const ll llinf = 0x3f3f3f3f3f3f3f3fLL;    // "infinity" for long long, same byte pattern
const ld pi = acos(-1.0L);                // pi in long double precision
// ---- per-node arrays (indexed by node id) ----
int head[maxn];      // index of the newest arc leaving the node in `edge`, -1 when none
int dep[maxn];       // depth of the node (root = 0)
int in[maxn];        // dfs index assigned on entry
int out[maxn];       // largest dfs index inside the node's subtree
// ---- Fenwick tree and per-depth-layer data ----
ll f[maxn];          // Fenwick tree over dfs indices
ll cnt[maxn];        // accumulated additions of each heavy depth layer
ll sz[maxn];         // number of nodes in each depth layer
bool visit[maxn];    // dfs visited flag
vector<int> v[maxn]; // v[d]: dfs indices of the depth-d nodes, in increasing order
vector<int> l;       // depth layers classified as heavy
// NOTE(review): an unqualified use of `size` is ambiguous with std::size under
// C++17 + `using namespace std;` -- prefer a differently named variable.
int size;            // heavy-layer threshold (sqrt(n))
int num = 0;         // number of arcs stored in `edge`
int dfn = 0;         // last dfs index handed out
// One directed arc of the array-based adjacency lists.
struct Edge {
    int from;  // source node
    int to;    // target node
    int pre;   // index of the previous arc leaving `from`; -1 ends the list
};
// Arc pool: addedge stores every tree edge as two arcs, hence twice maxn slots.
Edge edge[2 * maxn];
void addedge(int from,int to) {
edge[num]=(Edge){from,to,head[from]};
head[from]=num++;
edge[num]=(Edge){to,from,head[to]};
head[to]=num++;
}
void dfs(int now,int step) {
visit[now]=1;
dep[now]=step;
in[now]=out[now]=++dfn;
v[step].pb(dfn);
for (int i=head[now];i!=-1;i=edge[i].pre) {
int to=edge[i].to;
if (!visit[to]) {
dfs(to,step+1);
out[now]=dfn;
}
}
}
// Returns the lowest set bit of `a` (e.g. 12 -> 4), or 0 when a == 0.
// ~a + 1 is the two's-complement negation of a; the step size of the Fenwick tree.
int lowbit(int a) {
    const int negated = ~a + 1;
    return a & negated;
}
// Fenwick prefix query: returns f-accumulated sum over positions 1..tt (0 for tt == 0).
ll getsum(int tt) {
    ll total = 0;
    while (tt != 0) {
        total += f[tt];
        tt &= tt - 1;  // drop the lowest set bit (same as tt -= lowbit(tt) for tt > 0)
    }
    return total;
}
// Fenwick point update: adds c at 1-based position tt of a tree covering positions 1..n.
// Positions tt <= 0 are ignored: lowbit(0) == 0, so the loop would otherwise never
// terminate (and a negative index would write outside f).
void update(int tt,ll c,int n) {
    if (tt <= 0) return;
    for (int t = tt; t <= n; t += lowbit(t))
        f[t] += c;
}
int main() {
int n,q;
ll x,y,z;
scanf("%d%d",&n,&q);
size=sqrt(n);
num=0;memset(head,-1,sizeof(head));
for (int i=1;i<n;i++) {
scanf("%lld%lld",&x,&y);
addedge(x,y);
}
dfs(1,0);
for (int i=1;i<=n;i++) {
sz[i]=v[i].size();
if (sz[i]>=size) l.pb(i);
}
int m=l.size();
for (int T=1;T<=q;T++) {
scanf("%lld",&x);
if (x==1) {
scanf("%lld%lld",&x,&y);
if (sz[x]>=size) {
cnt[x]+=y;
} else {
for (int i=0;i<sz[x];i++) {
update(v[x][i],y,n);
}
}
} else {
scanf("%lld",&x);
ll ans=getsum(out[x])-getsum(in[x]-1);
for (int i=0;i<m;i++) {
if (l[i]>=dep[x]) {
int L,R;
R=lower_bound(v[l[i]].begin(),v[l[i]].end(),out[x])-v[l[i]].begin();
if (v[l[i]][R]==out[x]) R++;
L=lower_bound(v[l[i]].begin(),v[l[i]].end(),in[x])-v[l[i]].begin();
assert(R>=L);
ans+=cnt[l[i]]*(ll)(R-L);
}
}
printf("%lld\n",ans);
}
}
return 0;
}