题意:
就是给你一个以1为跟节点的树,m次操作,一种操作是让x所在的子树所有节点的权值都加上y,另一种是查询x所在子树所有节点的权值平方和。同时对23333取模。
思考:
一看是树上操作区间就知道是dfs序了,然后权值的平方和考过很多次了,就是手推一下公式就看出来了。(a+c)(a+c)+(b+c)(b+c)==aa+bb+2*(a+b)+2cc。所以还要维护一下区间和就行了。值得注意的呢是,如何把树的区间转化为线段树的区间,其实只要按每个节点的左边进行记录记录编号就行了。好比线段树的第一个区间是1到n,那么建树的时候就看看id[1]是多少就了。然后update和query操作都是一样的,仅仅就是建树的时候不怎么一样。
以前做过一道dfs序+树状数组的题:求和。可以结合着看看。
代码:
struct Node{
int L,R;
int sum;
int ans;
int laz;
}t[4*N];
int T,n,m,k;
int va[N];
int l[N],r[N],id[N],cnt;
vector<int > e[N];
void dfs(int now,int p)
{
l[now] = ++cnt;
id[cnt] = now;
for(auto spot:e[now])
{
if(spot==p) continue;
dfs(spot,now);
}
r[now] = cnt;
}
void pushup(int node)
{
t[node].sum = (t[node_l].sum+t[node_r].sum)%mod;
t[node].ans = (t[node_l].ans+t[node_r].ans)%mod;
}
void pushdown(int node)
{
int laz = t[node].laz;
if(laz)
{
t[node_l].laz = (t[node_l].laz+laz)%mod;
t[node_l].ans = (t[node_l].ans+2*laz%mod*t[node_l].sum%mod+laz*laz%mod*(t[node_l].R-t[node_l].L+1)%mod)%mod;
t[node_l].sum = (t[node_l].sum+laz*(t[node_l].R-t[node_l].L+1)%mod)%mod;
t[node_r].laz = (t[node_r].laz+laz)%mod;
t[node_r].ans = (t[node_r].ans+2*laz%mod*t[node_r].sum%mod+laz*laz%mod*(t[node_r].R-t[node_r].L+1)%mod)%mod;
t[node_r].sum = (t[node_r].sum+laz*(t[node_r].R-t[node_r].L+1)%mod)%mod;
t[node].laz = 0;
}
}
void build(int node,int l,int r)
{
t[node].L= l,t[node].R = r;
if(l==r)
{
t[node].sum = va[id[l]]%mod;
t[node].ans = va[id[l]]*va[id[l]]%mod;
return ;
}
int mid = (l+r)>>1;
build(node_l,l,mid);build(node_r,mid+1,r);
pushup(node);
}
void update(int node,int l,int r,int value)
{
if(t[node].L>=l&&t[node].R<=r)
{
t[node].laz = (t[node].laz+value)%mod;
t[node].ans = (t[node].ans+2*value%mod*t[node].sum%mod+value*value%mod*(t[node].R-t[node].L+1)%mod)%mod;
t[node].sum = (t[node].sum+value*(t[node].R-t[node].L+1)%mod)%mod;
return ;
}
pushdown(node);
int mid = (t[node].L+t[node].R)>>1;
if(l<=mid) update(node_l,l,r,value);
if(r>mid) update(node_r,l,r,value);
pushup(node);
}
int query(int node,int l,int r)
{
if(t[node].L>=l&&t[node].R<=r) return t[node].ans%mod;
pushdown(node);
int mid = (t[node].L+t[node].R)>>1;
if(r<=mid) return query(node_l,l,r)%mod;
else if(l>mid) return query(node_r,l,r)%mod;
else return (query(node_l,l,mid)+query(node_r,mid+1,r))%mod;
}
signed main()
{
IOS;
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>va[i];
for(int i=1;i<n;i++)
{
int a,b;
cin>>a>>b;
e[a].pb(b);
e[b].pb(a);
}
dfs(1,0);
build(1,1,n);
while(m--)
{
int op,a,b;
cin>>op>>a;
if(op==1)
{
cin>>b;
update(1,l[a],r[a],b);
}
else cout<<query(1,l[a],r[a])%mod<<"\n";
}
return 0;
}
总结:
多多积累经验呀。