一题非常巧妙的DS题。
假设1为根。
首先可以发现,把这条连从图中抠掉,那么图中会剩下几个联通快:
S
1
,
S
2
,
S
3
.
.
S
k
S_1,S_2,S_3..S_k
S1,S2,S3..Sk则
A
n
s
=
(
∑
i
=
1
n
a
i
)
2
−
∑
i
=
1
k
(
s
i
2
)
Ans=(\sum_{i=1}^n a_i)^2-\sum_{i=1}^k(s_i^2)
Ans=(∑i=1nai)2−∑i=1k(si2)。
以下所说的答案为
∑
i
=
1
k
(
s
i
2
)
\sum_{i=1}^k(s_i^2)
∑i=1k(si2).
这题的难点就在于怎样维护
∑
i
=
1
k
(
s
i
2
)
\sum_{i=1}^k(s_i^2)
∑i=1k(si2)?
一个不难想的思路就是对于一个节点
i
i
i,维护
Z
i
=
∑
j
∈
s
o
n
(
i
)
(
s
u
m
j
2
)
Z_i=\sum _{j\in son(i)} (sum_j^2)
Zi=∑j∈son(i)(sumj2)其中
s
u
m
j
sum_j
sumj表示已j为根的子树的权值总和。
对于询问
u
,
v
u,v
u,v,设他们的
L
C
A
=
l
LCA=l
LCA=l,则对于一侧
u
,
l
u,l
u,l,设
u
−
v
u-v
u−v的路径的点集为
S
S
S,答案
=
∑
j
∈
S
Z
j
+
(
s
u
m
1
−
s
u
m
l
)
2
−
∑
j
∈
S
,
j
≠
l
(
s
u
m
j
2
)
=\sum_{j\in S} Z_j+(sum_1-sum_l)^2-\sum_{j\in S,j\neq l}(sum_j^2)
=∑j∈SZj+(sum1−suml)2−∑j∈S,j=l(sumj2)。一看就是树链剖分+BIT的板子。
但是!这只是对于询问操作,那怎么修改呢?
一看就可以发现
s
u
m
sum
sum的修改树链剖分+BIT就足够了。但是平方的和,就不是那么显然了。
所以,此路不通!
不过以上的思路也不是一点没用,我们仔细分析一下它的瓶颈在哪?
就是我们在计算一条路径上的
∑
w
j
\sum w_j
∑wj的时候会算上当前路径上的
s
u
m
j
2
sum_j^2
sumj2,这也是为什么前面要减去
∑
j
∈
S
,
j
≠
l
(
s
u
m
j
2
)
\sum_{j\in S,j\neq l}(sum_j^2)
∑j∈S,j=l(sumj2)。在脑中脑补一下就可以发现其实我们不需要维护所有儿子的
s
u
m
2
sum^2
sum2只需要是轻儿子!
HLD 有一个非常优美的性质:当只维护轻儿子的信息时,修改可以沿着树链往上跳的时候直接修改,而计算答案时只需要在加上重儿子的贡献。
代码明天补。调了一下午才发现是多测。。。
/*
{By GWj
*/
#pragma GCC optimize(2)
#include<iostream>
#include<algorithm>
#include<assert.h>
#include<cstring>
#define rb(a,b,c) for(int a=b;a<=c;++a)
#define rl(a,b,c) for(int a=b;a>=c;--a)
#define LL long long
#define IT iterator
#define PB push_back
#define II(a,b) make_pair(a,b)
#define FIR first
#define SEC second
#define FREO freopen("check.out","w",stdout)
#define rep(a,b) for(int a=0;a<b;++a)
#define SRAND mt19937 rng(chrono::steady_clock::now().time_since_epoch().count())
#define random(a) rng()%a
#define ALL(a) a.begin(),a.end()
#define POB pop_back
#define ff fflush(stdout)
#define fastio ios::sync_with_stdio(false)
#define R(a) cin>>a
#define R2(a,b) cin>>a>>b
#define check_min(a,b) a=min(a,b)
#define check_max(a,b) a=max(a,b)
#define int LL
using namespace std;
const int INF=0x3f3f3f3f;
typedef pair<int,int> mp;
/*}
*/
const int MAXN=1e5+20;
const int MOD=1e9+7;
int w[MAXN],top[MAXN],heavy[MAXN],depth[MAXN],siz[MAXN],dfn[MAXN],fa[MAXN][19],n,q;
struct BIT{
LL bit[MAXN+10]={0};
LL sum(int i){
LL s=0;
while(i>0){
s+=bit[i];
s+=MOD;
s%=MOD;
i-=i&(-i);
}
return s;
}
void add(int i,LL x=1){
x%=MOD;
while(i<=MAXN){
bit[i]+=x;
bit[i]+=MOD;
bit[i]%=MOD;
i+=i&(-i);
}
}
LL query(int l,int r){
return (sum(r)-sum(l-1)+MOD)%MOD;
}
}sum,sq_sum;
vector<int> g[MAXN];
void dfs1(int now,int pre,int deep=1){
// cout<<now<<endl;
depth[now]=deep;
fa[now][0]=pre;
rb(i,1,18)
fa[now][i]=fa[fa[now][i-1]][i-1];
siz[now]=1;
for(int it:g[now]){
if(it!=pre){
dfs1(it,now,deep+1);
siz[now]+=siz[it];
}
}
for(auto it:g[now]){
if(it!=pre){
if(siz[it]*2>=siz[now]){
heavy[now]=it;
}
}
}
}
int jump(int now,int steps){
rep(i,19)
{
if((steps>>i)&1){
now=fa[now][i];
}
}
return now;
}
int lca(int u,int v){
if(depth[u]>depth[v]){
swap(u,v);
}
v=jump(v,depth[v]-depth[u]);
if(u==v) return u;
rl(i,18,0){
if(fa[u][i]!=fa[v][i]){
u=fa[u][i];
v=fa[v][i];
}
}
return fa[u][0];
}
int cnt=0;
void dfs2(int now,int pre){
dfn[now]=++cnt;
if(heavy[now]){
top[heavy[now]]=top[now];
dfs2(heavy[now],now);
}
for(auto it:g[now]){
if(it!=pre&&it!=heavy[now]){
top[it]=it;
dfs2(it,now);
}
}
}
void change(int index,LL d){
int now=index;
while(now){
if(fa[top[now]][0]){
sq_sum.add(dfn[fa[top[now]][0]],d*d+2ll*d*sum.query(dfn[top[now]],dfn[top[now]]+siz[top[now]]-1));
}
now=fa[top[now]][0];
}
sum.add(dfn[index],d);
}
LL all(int index){
LL rest=0;
rest=sq_sum.query(dfn[index],dfn[index])%MOD;
if(heavy[index]){
LL summ=sum.query(dfn[heavy[index]],dfn[heavy[index]]+siz[heavy[index]]-1)%MOD;
rest+=summ*summ%MOD;
rest%=MOD;
}
return rest;
}
LL all_sum=0;
void modify(int index,LL d){
change(index,-w[index]);
change(index,d);
all_sum-=w[index];
all_sum+=MOD;
all_sum%=MOD;
all_sum+=d;
all_sum%=MOD;
w[index]=d;
}
LL query(int u,int v){
LL rest=0;
while(depth[u]>=depth[v]&&u){
int nex;
if(depth[top[u]]<depth[v]){
nex=v;
}
else{
nex=top[u];
}
rest+=all(u);
rest%=MOD;
rest-=sq_sum.query(dfn[u],dfn[u])%MOD;
rest+=sq_sum.query(dfn[nex],dfn[u])%MOD;
rest+=MOD;
rest%=MOD;
if(fa[nex][0]&&depth[fa[nex][0]]>=depth[v]){
LL sum_=sum.query(dfn[nex],dfn[nex]+siz[nex]-1)%MOD;
rest-=sum_*sum_%MOD;
rest+=MOD;
rest%=MOD;
}
u=fa[nex][0];
}
return rest;
}
signed main(){
fastio;
while(cin>>n>>q){
all_sum=0;
cnt=0;
rb(i,1,n)
g[i].clear(),heavy[i]=0;
rb(i,1,n)
rb(j,0,18)
fa[i][j]=0;
memset(sum.bit,0,sizeof(sum.bit));
memset(sq_sum.bit,0,sizeof(sq_sum.bit));
rb(i,1,n){
R(w[i]);
}
rb(i,1,n-1){
int u,v;
R2(u,v);
g[u].PB(v);
g[v].PB(u);
}
dfs1(1,0);
top[1]=1;
dfs2(1,0);
rb(i,1,n){
int save_wi=w[i];
w[i]=0;
modify(i,save_wi);
}
rb(i,1,q){
int ty;
R(ty);
if(ty==1){
int u,w;
R2(u,w);
modify(u,w);
}
else{
int u,v;
R2(u,v);
LL rest=0;
int llc=lca(u,v);
LL get_all=all(llc);
int nu,nv;
if(u!=llc){
nu=jump(u,depth[u]-depth[llc]-1);
rest+=query(u,nu);
rest%=MOD;
LL tmp_sum=sum.query(dfn[nu],dfn[nu]+siz[nu]-1)%MOD;
get_all-=tmp_sum*tmp_sum%MOD;
get_all+=MOD;
get_all%=MOD;
}
if(v!=llc){
nv=jump(v,depth[v]-depth[llc]-1);
rest+=query(v,nv);
LL tmp_sum=sum.query(dfn[nv],dfn[nv]+siz[nv]-1)%MOD;
get_all-=tmp_sum*tmp_sum%MOD;
get_all+=MOD;
get_all%=MOD;
}
rest+=get_all;
rest%=MOD;
assert(all_sum==sum.query(1,n));
LL tmp_sum=sum.query(1,n)-sum.query(dfn[llc],dfn[llc]+siz[llc]-1);
tmp_sum%=MOD;
tmp_sum+=MOD;
tmp_sum%=MOD;
rest+=tmp_sum*tmp_sum%MOD;
rest%=MOD;
tmp_sum=all_sum;
tmp_sum%=MOD;
assert((tmp_sum*tmp_sum%MOD-rest+MOD)%MOD>=0);
cout<<(tmp_sum*tmp_sum%MOD-rest+MOD)%MOD<<endl;
}
}
}
return 0;
}
/*
1 1
1
2 1 1
*/