题目来源
https://ac.nowcoder.com/acm/contest/4370/F
题意分析
给出一棵树,有四种操作:
1 x y w 表示将从x到y这条简单路径的上所有点权改成w。
2 x y w 表示将从x到y这条简单路径的上所有点权加上w。
3 x y w 表示将从x到y这条简单路径的上所有点权乘上w。
4 x y 表示求出从x到y这条简单路径上的所有点的点权的立方和。
思路分析
学过树链剖分的,其实一看就会发现是树链剖分的模板题。主要难题在如何处理点权的修改和乘法。
首先手动模拟一下立方的加法和乘法是什么情况,进而决定用线段树需要维护的数列。首先需要维护的是立方和的加法,而在立方和加法中,如果对于其中的某个值进行加法运算,那么去掉括号之后会发现他的值和平方和以及一次方和有关,所以再维护一下这两个数值。然后就是长时间的码农时间了。
code
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e5 +7;
const int mod = 1e9 + 7;
int n, m, r;
int head[maxn], nxt[maxn << 1], ver[maxn << 1];
int sz[maxn], dep[maxn], fa[maxn], top[maxn], dfn[maxn], id[maxn];
int hcnt = 0;
int tot;
ll sum1[maxn << 2], sum2[maxn << 2], sum3[maxn << 2], a[maxn];
ll lzmul[maxn << 2], lzadd[maxn << 2];
void adde(int u, int v){
++tot; ver[tot] = v; nxt[tot] = head[u]; head[u] = tot;
}
void dfs1(int x, int f, int d){
fa[x] = f; dep[x] = d; sz[x] = 1;
for (int i=head[x]; i; i=nxt[i]){
int v = ver[i];
if (v == f) continue;
dfs1(v, x, d + 1);
sz[x] += sz[v];
}
}
void dfs2(int x, int f){
dfn[x] = ++hcnt; id[hcnt] = x;
top[x] = f;
int pp = 0;
for (int i=head[x]; i; i=nxt[i]){
int v = ver[i];
if (v == fa[x]) continue;
if (sz[v] > sz[pp]) pp = v;
}
if (pp == 0) return;
dfs2(pp, f);
for (int i=head[x]; i; i=nxt[i]){
int v = ver[i];
if (v == pp || v == fa[x]) continue;
dfs2(v, v);
}
}
void pushup(int p){
sum1[p] = sum1[p << 1] + sum1[p << 1 | 1];
sum1[p] %= mod;
sum2[p] = sum2[p << 1] + sum2[p << 1 | 1];
sum2[p] %= mod;
sum3[p] = sum3[p << 1] + sum3[p << 1 | 1];
sum3[p] %= mod;
}
void change(int p, ll x, ll y, int ln){
if (x != 1){
ll w = 1ll * x % mod;
ll w2 = w * w % mod;
ll w3 = w2 * w % mod;
sum3[p] *= w3; sum3[p] %= mod;
sum2[p] *= w2; sum2[p] %= mod;
sum1[p] *= w; sum1[p] %= mod;
lzmul[p] *= w; lzmul[p] %= mod;
lzadd[p] *= w; lzadd[p] %= mod;
}
if (y != 0){
ll w = 1ll * y % mod;
ll w2 = w * w % mod;
ll w3 = w2 * w % mod;
sum3[p] += (1ll* ln * w3 % mod); sum3[p] %= mod;
sum3[p] += (3ll * w2 * sum1[p]) % mod; sum3[p] %= mod;
sum3[p] += (3ll * w * sum2[p]) % mod; sum3[p] %= mod;
sum2[p] += (1ll * ln * w2 % mod); sum2[p] %= mod;
sum2[p] += (2ll * w * sum1[p]) % mod; sum2[p] %= mod;
sum1[p] += 1ll * ln * w; sum1[p] %= mod;
lzadd[p] += w; lzadd[p] %= mod;
}
}
void pushdown(int p, int lnl, int lnr){
ll x = lzmul[p], y = lzadd[p];
change(p << 1, x, y, lnl);
change(p << 1 | 1, x, y, lnr);
lzmul[p] = 1; lzadd[p] = 0;
}
void build(int p, int l, int r){
lzmul[p] = 1; lzadd[p] = 0;
if (l == r){
sum1[p] = a[id[l]] % mod;
sum2[p] = sum1[p] * sum1[p] % mod;
sum3[p] = sum2[p] * sum1[p] % mod;
return;
}
int mid = l + r >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
pushup(p);
}
ll sum(int p, int l, int r, int L, int R){
// cout << p << " " << l << " " << r << " " << L << " " << R << endl;
if (L <= l && r <= R) return sum3[p];
int mid = l + r >> 1;
pushdown(p, mid + 1 - l, r - mid);
ll ans = 0;
if (mid >= L) ans += sum(p << 1, l, mid, L, R);
ans %= mod;
if (mid < R) ans += sum(p << 1 | 1, mid + 1, r, L, R);
ans %= mod;
pushup(p);
return ans;
}
void update(int p, int l, int r, int L, int R, int x, int y){
// cout << p << " " << l << " " << r << " " << L << " " << R << " " << x << " " << y << endl;
if (L <= l && r <= R){
change(p, 1ll*x, 1ll*y, r - l + 1); return;
}
int mid = l + r >> 1;
pushdown(p, mid + 1 - l, r - mid);
if (L <= mid) update(p << 1, l, mid, L, R, x, y);
if (mid < R) update(p << 1 | 1, mid + 1, r, L, R, x, y);
pushup(p);
}
void upd(int u, int v, int x, int y){
while (top[u] != top[v]){
if (dep[top[u]] < dep[top[v]]) swap(u, v);
update(1, 1, n, dfn[top[u]], dfn[u], x, y);
u = fa[top[u]];
}
if (dep[u] > dep[v]) swap(u, v);
update(1, 1, n, dfn[u], dfn[v], x, y);
}
ll ask(int u, int v){
ll ans = 0;
while (top[u] != top[v]){
if (dep[top[u]] < dep[top[v]]) swap(u, v);
// cout << "!!!! " << u << " " << dfn[top[u]] << " " << dfn[u] << endl;
ans += sum(1, 1, n, dfn[top[u]], dfn[u]);
ans %= mod;
u = fa[top[u]];
}
ans %= mod;
if (dep[u] < dep[v]) ans += sum(1, 1, n, dfn[u], dfn[v]);
else if (dep[u] >= dep[v]) ans += sum(1, 1, n, dfn[v], dfn[u]);
ans %= mod;
return ans;
}
int main(){
int t; scanf("%d", &t);
int cas = 0;
while (t --){
tot = 0; hcnt = 0;
scanf("%d", &n);
for (int i=0; i<=n; i++){
head[i] = 0; sz[i] = 0;id[i] = 0; fa[i] = 0;
}
int r = 1;
// cout << "???" << endl;
for (int i=1; i<n; i++){
int u, v;
scanf("%d%d", &u, &v);
adde(u, v); adde(v, u);
}
// cout << "1111" << endl;
for (int i=1; i<=n; i++) scanf("%d", &a[i]);
dfs1(r, 0, 1);
// cout << "2222" << endl;
dfs2(r, r);
// cout << "3333" << endl;
build(1, 1, n);
// cout << "4444" << endl;
int q; scanf("%d", &q);
printf("Case #%d:\n", ++cas);
while (q --){
int op; scanf("%d", &op);
int u, v; scanf("%d%d", &u, &v);
int w;
if (op == 1){
scanf("%d", &w);
upd(u, v, 0, w);
}else if (op == 2){
scanf("%d", &w);
upd(u, v, 1, w);
}else if (op == 3){
scanf("%d", &w);
upd(u, v, w, 0);
}else if (op == 4){
printf("%d\n", ask(u, v) % mod);
}
}
}
return 0;
}