树上求和
链接:https://ac.nowcoder.com/acm/contest/6290/D
来源:牛客网
时间限制:C/C++ 1秒,其他语言2秒
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld
题目描述
给你一棵根为1的有N个节点的树,以及Q次操作。
每次操作诸如:
1 x y:将节点x所在的子树的所有节点的权值加上y
2 x:询问x所在子树的所有节点的权值的平方和,答案模23333后输出
输入描述:
第一行两个整数N,Q
第二行N个整数,第i个表示节点i的初始权值
接下来N-1行每行两个整数u,v,表示u和v之间存在一条树边
接下来Q行每行一个操作,格式如题目描述
输出描述:
对于每个询问操作,输出一行一个整数,表示答案在模23333后的结果
示例1
输入
复制
5 5
0 0 0 0 0
1 2
1 3
3 4
3 5
1 1 3
1 3 7
1 4 5
1 5 6
2 1
输出
复制
599
备注:
- 数据范围
一共有10个测试点,对于第i个测试点保证,N=10000 x i
对于100 %100%的数据,有1 ≤ N,Q,y ≤ 105,1 ≤ x ≤ N - 注
平方和的意思是:a2+b2+c2
(a+b+c)^2是和的平方
题解:dfs序处理出子树的区间,然后建线段树维护求和即可。
区间修改时,利用 (a+b)^2 = a^2 + b^2 + 2ab;计算公式即可。
dfs序
int tim=0;
void dfs(int x,int y) {
L[x]=++tim;
b[tim]=a[x];
for(int i=0; i<G[x].size(); ++i) {
if(G[x][i]!=y) {
dfs(G[x][i],x);
}
}
R[x]=tim;
}
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+500;
const ll mod=23333;
int n,m;
ll a[N],b[N];
vector<int>G[N];
int L[N],R[N];
struct Tree {
int l,r;
ll val,Val,tag;
} t[N*4];
int tim=0;
void dfs(int x,int y) {
L[x]=++tim;
b[tim]=a[x];
for(int i=0; i<G[x].size(); ++i) {
if(G[x][i]!=y) {
dfs(G[x][i],x);
}
}
R[x]=tim;
}
void updata(int p) {
t[p].val=(t[p*2].val+t[p*2+1].val)%mod;
t[p].Val=(t[p*2].Val+t[p*2+1].Val)%mod;
return;
}
void down(int p) {
if(t[p].tag) {
t[p*2].Val=(t[p*2].Val+(1ll*(t[p*2].r-t[p*2].l+1)*t[p].tag%mod*t[p].tag)%mod+2*t[p*2].val*t[p].tag%mod)%mod;
t[p*2].val=(t[p*2].val+ 1ll*(t[p*2].r-t[p*2].l+1)*t[p].tag%mod)%mod;
t[p*2].tag=(t[p*2].tag+t[p].tag)%mod;
t[p*2+1].Val=(t[p*2+1].Val+(1ll*(t[p*2+1].r-t[p*2+1].l+1)*t[p].tag%mod*t[p].tag)%mod+2*t[p*2+1].val*t[p].tag%mod)%mod;
t[p*2+1].val=(t[p*2+1].val+1ll*(t[p*2+1].r-t[p*2+1].l+1)*t[p].tag%mod)%mod;
t[p*2+1].tag=(t[p*2+1].tag+t[p].tag)%mod;
t[p].tag=0;
}
return;
}
void build(int p,int l,int r) {
t[p].l=l,t[p].r=r;
if(l==r) {
t[p].val=1ll*b[l]%mod;
t[p].Val=1ll*b[l]*b[l]%mod;
return;
}
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
updata(p);
}
void change(int p,int L,int R,ll x) {
if(L<=t[p].l&&t[p].r<=R) {
t[p].Val=(t[p].Val+1ll*((t[p].r-t[p].l+1)*x*x)%mod+2*t[p].val*x%mod)%mod;
t[p].val=(t[p].val+1ll*(t[p].r-t[p].l+1)*x%mod)%mod;
t[p].tag=(t[p].tag+1ll*x)%mod;
return;
}
down(p);
int mid=(t[p].l+t[p].r)/2;
if(L<=mid)change(p*2,L,R,x);
if(R>mid)change(p*2+1,L,R,x);
updata(p);
}
ll ask(int p,int L,int R) {
if(L<=t[p].l&&t[p].r<=R) {
return t[p].Val;
}
down(p);
int mid=(t[p].l+t[p].r)/2;
ll ans=0;
if(L<=mid)ans=(ans+ask(p*2,L,R))%mod;
if(R>mid)ans=(ans+ask(p*2+1,L,R))%mod;
return ans;
}
int main() {
scanf("%d %d",&n,&m);
for(int i=1; i<=n; ++i) {
scanf("%lld",&a[i]);
}
for(int i=1; i<n; ++i) {
int x,y;
scanf("%d %d",&x,&y);
G[x].push_back(y);
G[y].push_back(x);
}
dfs(1,0);
build(1,1,n);
while(m--) {
int z;
scanf("%d",&z);
if(z==1) {
int x,y;
scanf("%d %d",&x,&y);
change(1,L[x],R[x],1ll*y);
} else {
int x;
scanf("%d",&x);
printf("%lld\n",ask(1,L[x],R[x]));
}
}
return 0;
}