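// 思路:列出求和的表达式,然后用树状数组进行维护。这题使我明白了树链剖分的意义在于将节点重新标号,使得每条重链上的节点标号连续,这样可以将树上的路径查询问题转化为多个区间查询子问题,于是考虑问题时只需要考虑区间查询即可。
// (Approach: write out the summation formula and maintain its terms with Fenwick trees. Heavy-light decomposition relabels the nodes so that every heavy chain gets consecutive labels; a path query on the tree then becomes several range queries, so only range queries need to be reasoned about.)
// 坑点:模数为 1e9+7,但输入的范围却达到 2e9,因此输入之后一定要先取模。
// (Pitfall: the modulus is 1e9+7 but input values can reach 2e9, so reduce them modulo 1e9+7 right after reading.)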
#include <bits/stdc++.h>
using namespace std;
typedef long long lint;
typedef long long LL;
// Array capacities (nodes up to ~3e5; maxm covers both directions of every edge).
const lint maxn_bit = 300005;
const lint maxn = 300005;
const lint maxm = 600005;
const lint mod = 1e9+7;

// Lowest set bit of n (the Fenwick-tree step); lowbit(0) == 0.
inline int lowbit(int n) {
    return n & -n;
}

// Map any long long to its canonical residue in [0, mod).
inline lint reduce_mod(lint a) {
    a %= mod;
    return a < 0 ? a + mod : a;
}

// (a - b) mod `mod`, always in [0, mod). Operands must not overflow when subtracted.
lint sub( lint a,lint b ){
    lint r = (a - b) % mod;
    return r < 0 ? r + mod : r;
}

// (a * b) mod `mod`, always in [0, mod).
// Both operands are reduced first: callers pass raw, unreduced path-cost prefix sums
// and differences (possibly large or negative), and multiplying those directly can
// overflow signed 64-bit arithmetic, which is undefined behaviour.
lint mul( lint a,lint b ){
    return reduce_mod(a) * reduce_mod(b) % mod;
}

// (a + b) mod `mod`, always in [0, mod). Operands must not overflow when added.
lint ad( lint a,lint b ){
    lint r = (a + b) % mod;
    return r < 0 ? r + mod : r;
}

// S[i]: exact (unreduced) prefix sum of node costs over HLD positions 1..i; S[0] = 0.
lint S[maxn];
// Fenwick (binary indexed) tree over positions 1..n; all sums are taken modulo `mod`
// via the file's ad()/sub() helpers.
template<typename T>
struct BIT {
    T s[maxn_bit]; // 1-indexed Fenwick array; s[0] is unused
    int n;         // number of active positions (1..n)

    // Reset positions 1..n to zero.
    void init(int _n) {
        n = _n;
        memset(s, 0, (n+1) * sizeof(s[0]));
    }

    // Build in O(n) from a[1..n]: s[i] = a[i-lowbit(i)+1] + ... + a[i] (mod),
    // computed as a difference of prefix sums.
    void init(int _n, T a[]) {
        static T sum[maxn_bit]; // prefix sums; sum[0] stays 0 (static zero-init, never written)
        n = _n;
        for (int i = 1; i <= n; ++i) {
            sum[i] = ad(sum[i-1], a[i]);
            s[i] = sub(sum[i], sum[i - lowbit(i)]);
        }
    }

    // Point update: a[x] += v (mod). Positions <= 0 are ignored; without the guard,
    // x == 0 would loop forever because lowbit(0) == 0.
    void Add(int x, T v) {
        if (x <= 0) return;
        for (; x <= n; x += lowbit(x)) {
            s[x] = ad(s[x], v);
        }
    }

    // Prefix sum a[1] + ... + a[x] (mod). x is clamped to [0, n] so out-of-range
    // indices neither read stale entries nor walk negative indices.
    T Sum(int x) {
        if (x > n) x = n;
        T ans = 0;
        for (; x > 0; x ^= lowbit(x)) {
            ans = ad(ans, s[x]);
        }
        return ans;
    }
};
LL fav[maxn];    // node values, indexed by HLD position (filled in main)
lint cost[maxn]; // per-node cost, indexed by node id
// Adjacency lists stored as linked lists of edges:
//   he[u]  = index of the most recently added edge leaving u (0 = none)
//   ne[e]  = next edge in the same list, ver[e] = target node of edge e
// id[pos] = node placed at HLD position pos (filled by dfs2).
lint tot;
lint he[maxn];
lint ver[maxm];
lint ne[maxm];
lint id[maxn];
LL v[maxn];      // NOTE(review): appears unused; every `v` below is a local/parameter that shadows it

// Append the directed edge x -> y at the head of x's adjacency list.
void add( lint x,lint y){
    const lint e = ++tot;
    ver[e] = y;
    ne[e] = he[x];
    he[x] = e;
}
lint sz[maxn],f[maxn],son[maxn],d[maxn];
// First heavy-light pass over the subtree rooted at x (the caller sets f[x] and d[x]).
// For every node u reached it fills f[u] (parent), d[u] (depth), sz[u] (subtree size)
// and son[u] (heavy child: the first child in adjacency order with the largest
// subtree, 0 for a leaf).
// Iterative (BFS order instead of recursion) so that deep, path-shaped trees with
// ~3e5 nodes cannot overflow the call stack.
void dfs1(lint x){
    static lint order[maxn]; // BFS order: every node appears after its parent
    lint cnt = 0;
    order[cnt++] = x;
    for (lint i = 0; i < cnt; ++i) {
        const lint u = order[i];
        sz[u] = 1;
        son[u] = 0;
        for (lint cure = he[u]; cure; cure = ne[cure]) {
            const lint y = ver[cure];
            if (y == f[u]) continue;
            d[y] = d[u] + 1;
            f[y] = u;
            order[cnt++] = y;
        }
    }
    // Reverse BFS order handles children before their parent, so each sz[] is final
    // when it is folded into the parent. i > 0 skips the root (its parent is outside).
    for (lint i = cnt - 1; i > 0; --i) {
        sz[f[order[i]]] += sz[order[i]];
    }
    // Heavy child: strict '>' keeps the first maximal child, matching the original
    // recursive tie-breaking.
    for (lint i = 0; i < cnt; ++i) {
        const lint u = order[i];
        lint best = 0;
        for (lint cure = he[u]; cure; cure = ne[cure]) {
            const lint y = ver[cure];
            if (y == f[u]) continue;
            if (sz[y] > best) {
                best = sz[y];
                son[u] = y;
            }
        }
    }
}
lint top[maxn],h[maxn],num;
// Second heavy-light pass: assigns positions h[u] = ++num (id[] is the inverse map)
// in a pre-order that visits the heavy child first, so every heavy chain occupies a
// contiguous range of positions; top[u] is the head of u's chain.
// Requires dfs1 to have filled f[] and son[]; num is reset by the caller (init).
// Iterative with an explicit stack, visiting nodes in exactly the same order as the
// recursive formulation, so deep trees cannot overflow the call stack.
void dfs2( lint x ){
    static lint stk[maxn];  // pending nodes; each node is pushed exactly once
    static lint kids[maxn]; // scratch: light children of the node being expanded
    lint sp = 0;
    stk[sp++] = x;
    while (sp > 0) {
        const lint u = stk[--sp];
        h[u] = ++num;
        id[num] = u;
        // A heavy child continues its parent's chain; any other node starts a new one.
        top[u] = (son[f[u]] == u) ? top[f[u]] : u;
        lint k = 0;
        for (lint cure = he[u]; cure; cure = ne[cure]) {
            const lint y = ver[cure];
            if (y == f[u] || y == son[u]) continue;
            kids[k++] = y;
        }
        // Push light children in reverse, then the heavy child on top: pops then yield
        // the heavy subtree first, followed by light children in adjacency order.
        while (k > 0) stk[sp++] = kids[--k];
        if (son[u]) stk[sp++] = son[u];
    }
}
// Exact (unreduced) total cost of the nodes on the tree path x..y, both ends included.
// Uses the prefix sums S[] over HLD positions: repeatedly lift the end whose chain head
// is deeper, adding that chain range, until both ends share a chain, then add the
// remaining range between them.
lint solve1( lint x,lint y ){
    lint total = 0;
    for (; top[x] != top[y]; x = f[top[x]]) {
        if (d[top[y]] > d[top[x]]) swap(x, y);
        total += S[h[x]] - S[h[top[x]] - 1];
    }
    // On a single chain the deeper node has the larger position.
    const lint lower = (d[x] < d[y]) ? y : x;
    const lint upper = (lower == x) ? y : x;
    return total + (S[h[lower]] - S[h[upper] - 1]);
}
// Point-assign a[x] = v and mirror the change into `bit`, the Fenwick tree built over
// a[]: the old value is removed by adding (mod - old), then the new value is added.
void update( LL a[],BIT<LL>& bit,lint x,LL v ){
    const LL previous = a[x];
    bit.Add( x, mod - previous );
    a[x] = v;
    bit.Add( x, v );
}
// Sum (mod `mod`) of the tree's values over positions [min(l, r), max(l, r)];
// the bounds may be given in either order.
LL ask( BIT<LL>& bit,lint l,lint r ){
    const lint lo = min(l, r);
    const lint hi = max(l, r);
    return sub( bit.Sum(hi), bit.Sum(lo - 1) );
}
// Fenwick trees over HLD positions, built in main from:
//   bit  : fav[i]
//   bit2 : favs[i]  = fav[i] * S[i-1]
//   bit3 : favs2[i] = fav[i] * S[i]
BIT<LL> bit, bit2, bit3;
// Path query along x -> y with parameter v (passed in as v1; main reduces it mod first).
// Returns, modulo mod,
//     sum over the nodes u of the path, in order from x to y, of
//     fav[u] * (v - total cost of the path nodes strictly before u),
// where fav[] is the array stored in `bit` and node costs come from the prefix sums S[].
//
// Every chain segment is a contiguous range of HLD positions, so the sum splits into
// range sums over three Fenwick trees: bit = sum fav[p], bit2 = sum fav[p]*S[p-1],
// bit3 = sum fav[p]*S[p]. Let head be the segment's shallowest node (top[x], or y on
// the final chain).
//  - x-side segments (flag == 1) are walked upward (from higher to lower position).
//    After v1 has been reduced by all x-side costs up to and including this segment,
//    node p contributes fav[p] * (v1 - S[head-1] + S[p]).
//  - y-side segments (flag == 0) are walked downward. v2 starts at v - (whole path
//    cost) and accumulates the y-side costs from the current head down to y, so node p
//    contributes fav[p] * (v2 + S[head-1] - S[p-1]).
//
// NOTE(review): v1, v2 and S[...] are raw long longs (possibly large or negative) that
// are passed straight to mul(); correctness relies on mul() not overflowing on them.
LL solve_ask( lint x,lint y,LL v1 ){
// flag == 1 while the variable x is on the original x's side of the path; each swap toggles it.
lint flag = 1;
LL res = 0;
LL v2 = v1 - solve1( x,y );
while( top[x] != top[y] ){
// Always lift the end whose chain head is deeper.
if( d[top[x]] < d[ top[y] ] ){
swap(x,y);flag ^= 1;
}
if( flag ){
// x-side segment [top[x], x]: v1*sum(fav) + sum(fav*S[p]) - S[head-1]*sum(fav).
v1 -= solve1( x,top[x] );
LL x1 = mul(v1, ask( bit,h[x],h[top[x]] ) ) ;
LL x2 = ask( bit3,h[x],h[top[x]] );
LL x3 = mul( S[ h[top[x]]-1],ask( bit,h[x],h[ top[x] ] ) );
res = ad(res,sub(ad( x1,x2 ),x3));
}else{
// y-side segment [top[x], x]: v2*sum(fav) - sum(fav*S[p-1]) + S[head-1]*sum(fav).
// The head's correction term S[head-1] - S[head-1] is 0, so the correction ranges
// start at head+1. The guard skips them when the segment is the head alone, because
// ask() would otherwise swap the bounds and wrongly include position head+1.
v2 += solve1( x,top[x] );
LL x1 = mul( v2,ask( bit,h[x],h[ top[x] ] ));
LL x2 = h[x] >= h[ top[x] ]+1 ?ask( bit2,h[x],h[ top[x] ]+1 ) : 0;
LL x3 = h[x] >= h[ top[x] ]+1 ?mul( S[h[top[x]]-1],ask( bit,h[x],h[ top[x] ]+1 ) ):0;
res = ad( res,ad(sub( x1,x2 ),x3) );
}
x = f[ top[x] ];
}
// Same chain now: make x the deeper end, so y is the LCA (the head of the final segment).
if( d[x] < d[y] ) {
swap(x, y);flag^=1;
}
if( flag ){
// Final x-side segment [y, x], including the LCA y.
v1 -= solve1( x,y );
LL x1 = mul(v1, ask( bit,h[x],h[y] ) ) ;
LL x2 = ask( bit3,h[x],h[y] );
LL x3 = mul( S[ h[y]-1],ask( bit,h[x],h[y] ) );
res = ad(res,sub(ad( x1,x2 ),x3));
}else{
// Final y-side segment [y, x], including the LCA y; same head+1 guard as above.
v2 += solve1( x,y );
LL x1 = mul( v2,ask( bit,h[x],h[y] ));
LL x2 = h[x] >=h[y]+1? ask( bit2,h[x],h[y]+1 ):0;
LL x3 = h[x] >= h[y]+1?mul( S[h[y]-1],ask( bit,h[x],h[y]+1 ) ):0;
res = ad( res,ad(sub( x1,x2 ),x3) );
}
return res;
}
LL favs[maxn];  // fav[i] * S[i-1], indexed by HLD position (backs bit2)
LL fa[maxn];    // raw per-node values as read, indexed by node id
LL favs2[maxn]; // fav[i] * S[i], indexed by HLD position (backs bit3)

// Reset per-test-case state for a tree with n nodes: the DFS position counter, the
// root's (node 1) parent and depth, the edge counter, and the heads of the adjacency
// lists he[0..n].
void init( lint n ){
    num = 0;
    f[1] = 0;
    d[1] = 0;
    tot = 1;
    for (lint i = n; i >= 0; --i) {
        he[i] = 0;
    }
}
int main(){
lint T,n;
scanf("%lld",&T);
while(T--){
scanf("%lld",&n);
init(n);
for( lint i = 1;i <= n;i++ ) {
scanf("%lld", &fa[i]);
fa[i] %= mod;
}
for( lint i = 1;i <= n;i++ )scanf("%lld",&cost[i]);
for( lint i = 1;i <= n-1;i++ ){
lint x,y;
scanf("%lld%lld",&x,&y);
add(x,y);add(y,x);
}
dfs1(1);dfs2(1);
S[0] = 0;
for( lint i = 1;i <= n;i++ ){
S[ i ] = S[i-1]+cost[ id[i] ];
}
for( lint i = 1;i <= n;i++ ){
fav[i] = fa[ id[i] ];
}
for( lint i = 1;i <= n;i++ ){
favs2[i] = mul( fav[i],S[i] );
}
for( lint i = 1;i <= n;i++ ){
favs[i] = mul( fav[i ],S[i-1] );
}
bit.init(n,fav);bit2.init(n,favs);bit3.init( n,favs2 );
lint q;
scanf("%lld",&q);
for( lint op,i = 1;i <= q;i++ ){
scanf("%lld",&op);
lint x,y;
LL v;
if( op == 1 ){
scanf("%lld%lld%lld",&x,&y,&v);
v%=mod;
LL ans = solve_ask(x,y,v);
cout << ans << endl;
}else{
scanf("%lld%lld",&x,&v);
v%=mod;
update(fav,bit,h[x],v);
update( favs,bit2,h[x],mul( v,S[ h[x]-1 ] ) );
update( favs2,bit3,h[x],mul( v,S[ h[x] ] ) );
}
}
}
return 0;
}