答案等于所有点权之和的平方,减去把两点间路径上的所有点删掉以后,剩下的每个连通块(子树)权值和的平方之和。
然后用线段树维护每个点的所有轻儿子子树权值和的平方之和。维护起来相当恶心,我写了一晚上才调出来。
#include<bits/stdc++.h>
#define LL long long
#define LD long double
#define ull unsigned long long
#define fi first
#define se second
#define mk make_pair
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
#define SZ(x) ((int)x.size())
#define ALL(x) (x).begin(), (x).end()
#define fio ios::sync_with_stdio(false); cin.tie(0);

using namespace std;

/*
 * Tree with point-weight updates and path queries, everything modulo `mod`.
 *
 * For a query (u, v) the program prints
 *     (total weight)^2 - sum over every component that remains after deleting
 *                        the nodes of path u..v of (component weight)^2.
 *
 * Data structures:
 *   - heavy-light decomposition (in/ot/top/son/pa) laid out by DFS order;
 *   - `bit`  : Fenwick tree over DFS order holding node weights, so that a
 *              subtree weight is a range sum over [in[u], ot[u]];
 *   - `Tree` : segment tree over DFS order; position in[x] holds the sum of
 *              (subtree weight)^2 over the LIGHT children of x.
 */

const int N = 1e5 + 7;
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-8;
const double PI = acos(-1);

// a = (a + b) mod p; requires 0 <= a < mod and 0 <= b <= mod.
template<class T, class S> inline void add(T &a, S b) {a += b; if(a >= mod) a -= mod;}
// a = (a - b) mod p; requires 0 <= a < mod and 0 <= b <= mod.
template<class T, class S> inline void sub(T &a, S b) {a -= b; if(a < 0) a += mod;}
template<class T, class S> inline bool chkmax(T &a, S b) {return a < b ? a = b, true : false;}
template<class T, class S> inline bool chkmin(T &a, S b) {return a > b ? a = b, true : false;}

//mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int n, q;
int now;                 // current total weight of the tree, mod p
LL w[N];                 // node weights, kept reduced mod p
LL sum[N];               // subtree weight computed by dfs (mod p)
LL val[N];               // sum of squared light-child subtree weights (mod p)
vector<int> G[N];
int depth[N], top[N], son[N], sz[N], pa[N];
int in[N], ot[N], rk[N], idx;   // DFS entry/exit index, index -> node, counter

// x^2 mod p for 0 <= x < mod.
inline int sq(int x) {
    return static_cast<int>(1LL * x * x % mod);
}

// Fenwick tree of node weights indexed by DFS order (1..n), values mod p.
struct Bit {
    int a[N];

    void init() {
        for(int i = 1; i <= n; i++) a[i] = 0;
    }

    // Adds v (0 <= v <= mod) at position x.
    void modify(int x, int v) {
        for(int i = x; i <= n; i += i & -i) add(a[i], v);
    }

    // Prefix sum over [1, x].
    int sum(int x) {
        int ans = 0;
        for(int i = x; i; i -= i & -i) add(ans, a[i]);
        return ans;
    }

    // Sum over [L, R]; an empty range yields 0.
    int query(int L, int R) {
        if(L > R) return 0;
        return (sum(R) - sum(L - 1) + mod) % mod;
    }
} bit;

// Point-add / range-sum segment tree over DFS order, values mod p.
struct SegmentTree {
#define lson l, mid, rt << 1
#define rson mid + 1, r, rt << 1 | 1
    int a[N << 2];

    // Recomputes node rt from its two children.
    void pull(int rt) {
        a[rt] = a[rt << 1] + a[rt << 1 | 1];
        if(a[rt] >= mod) a[rt] -= mod;
    }

    // Leaf l is initialised with val[] of the node at DFS position l.
    void build(int l, int r, int rt) {
        if(l == r) {
            a[rt] = static_cast<int>(val[rk[l]]);
            return;
        }
        int mid = (l + r) >> 1;
        build(lson);
        build(rson);
        pull(rt);
    }

    // Adds delta (0 <= delta <= mod) to position p.
    void update(int p, int delta, int l, int r, int rt) {
        if(l == r) {
            add(a[rt], delta);
            return;
        }
        int mid = (l + r) >> 1;
        if(p <= mid) update(p, delta, lson);
        else update(p, delta, rson);
        pull(rt);
    }

    // Sum over [L, R]; ranges that miss [l, r] or are empty yield 0.
    int query(int L, int R, int l, int r, int rt) {
        if(R < l || r < L || R < L) return 0;
        if(L <= l && r <= R) return a[rt];
        int mid = (l + r) >> 1;
        return (query(L, R, lson) + query(L, R, rson)) % mod;
    }
} Tree;

// First pass: parent, depth, subtree size, subtree weight and heavy child.
void dfs(int u, int fa) {
    pa[u] = fa;
    sz[u] = 1;
    sum[u] = w[u];
    depth[u] = depth[fa] + 1;
    for(int v : G[u]) {
        if(v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
        add(sum[u], sum[v]);
        if(sz[son[u]] < sz[v]) son[u] = v;
    }
}

// Second pass: DFS numbering with the heavy child visited first (so every
// chain is contiguous), chain heads, and val[u] accumulated from light children.
void dfs2(int u, int fa, int from) {
    in[u] = ++idx;
    rk[idx] = u;
    top[u] = from;
    if(son[u]) dfs2(son[u], u, from);
    for(int v : G[u]) {
        if(v == fa || v == son[u]) continue;
        dfs2(v, u, v);
        add(val[u], 1LL * sum[v] * sum[v] % mod);
    }
    ot[u] = idx;
}

// Current weight of the subtree rooted at u; node 0 (no node) weighs 0.
inline int SUM(int u) {
    if(!u) return 0;
    return bit.query(in[u], ot[u]);
}

// Propagates a weight change of `val` (mod p) at node u into the segment tree.
// Must be called after `bit` already reflects the new weight. Only chain heads
// are light children, so only their parents' entries have to be refreshed.
void update(int u, int val) {
    for(int fu = top[u]; fu; u = pa[fu], fu = top[u]) {
        if(!pa[fu]) continue;               // the root is nobody's light child
        int cur = SUM(fu);                  // subtree weight after the change
        int pre = (cur - val + mod) % mod;  // subtree weight before the change
        Tree.update(in[pa[fu]], mod - sq(pre), 1, n, 1);
        add(sum[fu], val);
        Tree.update(in[pa[fu]], sq(cur), 1, n, 1);
    }
}

// Returns, mod p, the sum of (component weight)^2 over all components that
// remain once every node on the path u..v is removed.
int query(int u, int v) {
    int ans = 0;
    int fu = top[u], fv = top[v];
    while(fu != fv) {
        // Always climb from the chain whose head is deeper. The original test
        // was inverted, which lifted the shallower chain past the root and
        // never terminated for endpoints on different chains.
        if(depth[fu] < depth[fv]) {
            swap(u, v);
            swap(fu, fv);
        }
        // Light children hanging off the chain segment fu..u.
        add(ans, Tree.query(in[fu], in[u], 1, n, 1));
        // The heavy child of the segment's bottom node is off the path.
        add(ans, sq(SUM(son[u])));
        // fu is a light child of pa[fu] and lies on the path: cancel the term
        // that pa[fu]'s segment-tree entry will contribute for it.
        if(pa[fu]) sub(ans, sq(SUM(fu)));
        u = pa[fu];
        fu = top[u];
    }
    // Same chain now: make u the upper end (the LCA) and v the lower end.
    if(depth[u] > depth[v]) swap(u, v);
    add(ans, Tree.query(in[u], in[v], 1, n, 1));
    add(ans, sq(SUM(son[v])));
    // Everything outside the subtree of the LCA forms one more component.
    int above = (now - bit.query(in[u], ot[u]) + mod) % mod;
    add(ans, sq(above));
    return ans;
}

// Resets per-test-case state for nodes 1..n.
void init() {
    idx = 0;
    bit.init();
    for(int i = 1; i <= n; i++) {
        G[i].clear();
        val[i] = 0;
        sum[i] = 0;
        son[i] = 0;
        sz[i] = 0;
    }
}

// Input, repeated until it is exhausted: n q, the n weights, n - 1 edges,
// then q operations "1 u x" (set w[u] = x) or "2 u v" (print the answer).
int main() {
    while(scanf("%d%d", &n, &q) == 2) {
        init();
        for(int i = 1; i <= n; i++) {
            if(scanf("%lld", &w[i]) != 1) return 0;
            w[i] = (w[i] % mod + mod) % mod;   // keep the modular invariants
        }
        for(int i = 1; i < n; i++) {
            int u, v;
            if(scanf("%d%d", &u, &v) != 2) return 0;
            G[u].push_back(v);
            G[v].push_back(u);
        }
        dfs(1, 0);
        dfs2(1, 0, 1);
        for(int i = 1; i <= n; i++) {
            bit.modify(in[i], static_cast<int>(w[i]));
        }
        now = static_cast<int>(sum[1]);
        Tree.build(1, n, 1);
        while(q--) {
            int op, u, v;
            if(scanf("%d%d%d", &op, &u, &v) != 3) return 0;
            if(op == 1) {
                int nv = (v % mod + mod) % mod;
                int tmp = static_cast<int>((nv - w[u] + mod) % mod);
                sub(now, w[u]);
                bit.modify(in[u], static_cast<int>(mod - w[u]));
                w[u] = nv;
                add(now, w[u]);
                bit.modify(in[u], static_cast<int>(w[u]));
                update(u, tmp);   // the BIT is already up to date here
            } else {
                printf("%lld\n", (1LL * now * now % mod - query(u, v) + mod) % mod);
            }
        }
    }
    return 0;
}