题意:
给一棵树(n个点,n-1条边), 每个节点有权值。
四个操作:
- u, v, w 将u到v路径上所有的点权值赋值为w
- u, v, w 将u到v路径上所有的点权值增加w
- u, v, w 将u到v路径上所有的点权值乘w
- u, v 查询u 到 v路径上所有点 三次方的和
含u,v节点
思路:
树链剖分一定是没问题了,就是查询操作和改变操作有点麻烦:
线段树两个懒标记:mul记录乘操作 add记录加操作
- 情况一:mul = 0, add = num;
- 情况二:mul = 1, add = num;
- 情况三:mul = num, add = 0;
剩下就是每种情况对应的操作了需要记录节点的和, 平方和,立方和。
操作一,操作二:相当于 a^3 -> (a+b)^3 拆开就会发现 是a^3 + b^3 + 3ab^2 + 3ba^2; 平方操作类同
操作三:直接乘就好, 注意的是需要将懒标记add 对应相乘
这种题就是麻烦,码量大一点 也没什么技巧,
莫忘初始化 // 找半天bug 发现又是它
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
const int maxn = 1e5 + 7;
const int mod = 1e9 + 7;
int h[maxn], e[maxn<<1], ne[maxn<<1], w[maxn], cnt;
int n;
struct Node{
int l, r;
int mul, add;
int sum, ssum, csum; // 和, 平方和, 立方和
}T[maxn<<2];
void add(int u, int v){
e[cnt] = v;
ne[cnt] = h[u];
h[u] = cnt ++;
}
int son[maxn], top[maxn];
int sz[maxn], fa[maxn], seq[maxn], dseq[maxn], d[maxn];
void dfs_son(int u, int f){
sz[u] = 1;
for (int i = h[u]; ~i; i = ne[i]){
int v = e[i];
if (v == f) continue;
d[v] = d[u] + 1;
fa[v] = u;
dfs_son(v, u);
sz[u] += sz[v];
if (sz[v] > sz[son[u]]) son[u] = v;
}
}
int tot;
void dfs_seq(int u, int tp){
seq[u] = ++tot;
dseq[tot] = u;
top[u] = tp;
// cout << son[u] << endl;
if (son[u]) dfs_seq(son[u], tp);
for (int i = h[u]; ~i; i = ne[i]){
int v = e[i];
if (v == son[u] || v == fa[u]) continue;
dfs_seq(v, v);
}
}
void up(int k){
T[k].sum = (T[k<<1].sum + T[k<<1|1].sum)%mod;
T[k].ssum = (T[k<<1].ssum + T[k<<1|1].ssum)%mod;
T[k].csum = (T[k<<1].csum + T[k<<1|1].csum)%mod;
}
void build (int l = 1, int r = n, int k = 1){
T[k].l = l;T[k].r = r;
T[k].mul = 1, T[k].add = 0;
if (l == r){
int val = w[dseq[l]];
T[k].sum = val;
T[k].ssum = 1ll * val * val % mod;
T[k].csum = 1ll * val * val % mod * val % mod;
return;
}
int mid = l + r >> 1;
build (l, mid, k << 1);
build (mid + 1, r, k << 1 | 1);
up(k);
}
void add1(int &x, int y){if ((x += y) >= mod) x -= mod;}
void mul1(int &x, int y) {x = 1ll * x * y %mod;}
void update(int mu, int ad, int sz, int k){
if (mu != 1){
int val = mu;
int val2 = 1ll * mu * mu % mod;
int val3 = 1ll * val2 * mu % mod;
mul1(T[k].csum, val3);
mul1(T[k].ssum, val2);
mul1(T[k].sum, val);
mul1(T[k].mul, mu);
mul1(T[k].add, mu);
}
if (ad != 0){
int val = ad;
int val2 = 1ll * ad * ad % mod;
int val3 = 1ll * val2 * ad % mod;
add1(T[k].csum, 1ll * sz * val3 %mod);//a^3
add1(T[k].csum, 3ll * val * T[k].ssum % mod); //3ab^2
add1(T[k].csum, 3ll * T[k].sum * val2 % mod); // 3ba^2
add1(T[k].ssum, 1ll * sz * val2%mod); // a^2
add1(T[k].ssum, 2ll * val * T[k].sum % mod);// 2ab
add1(T[k].sum, 1LL * sz * val % mod);
add1(T[k].add, ad);
}
}
void pushDown(int szl, int szr, int k){
int x = T[k].mul, y = T[k].add;
update(x, y, szl, k << 1);
update(x, y, szr, k << 1 | 1);
T[k].mul = 1;
T[k].add = 0;
}
void modify1(int l, int r, int mu, int ad, int k = 1){
if (l > T[k].r || r < T[k].l) return;
if (l <= T[k].l && r >= T[k].r){
update(mu, ad, T[k].r - T[k].l + 1, k);
return;
}
int mid = T[k].l + T[k].r >> 1;
pushDown(mid - T[k].l + 1, T[k].r - mid, k);
if (l <= mid) modify1(l, r, mu, ad, k << 1);
if (r > mid) modify1(l, r, mu, ad, k << 1 | 1);
up(k);
}
void modify(int x, int y, int mu, int ad){
while (top[x] != top[y]){
if (d[top[x]] < d[top[y]]) swap(x, y);
modify1(seq[top[x]], seq[x], mu, ad);
x = fa[top[x]];
}
if (d[x] > d[y]) swap(x, y);
modify1(seq[x], seq[y], mu, ad);
}
int query1(int l, int r, int k = 1){
if (l > T[k].r || r < T[k].l) return 0;
if (l <= T[k].l && r >= T[k].r){
return T[k].csum;
}
int mid = T[k].l + T[k].r >> 1;
pushDown(mid - T[k].l + 1, T[k].r - mid, k);
int res = 0;
if (l <= mid) add1(res, query1(l, r, k << 1));
if (r > mid) add1(res, query1(l, r, k << 1 | 1));
up(k);
return res;
}
int query(int x, int y){
int res = 0;
while (top[x] != top[y]){
if (d[top[x]] < d[top[y]]) swap(x, y);
add1(res, query1(seq[top[x]], seq[x]));
x = fa[top[x]];
}
if (d[x] > d[y]) swap(x, y);
add1(res, query1(seq[x], seq[y]));
return res;
}
int main (){
int T, tol = 1;
scanf ("%d", &T);
while (T -- ){
scanf ("%d", &n);
//初始化 !!!!!!
for (int i = 0; i <= n; i ++ ) h[i] = -1, son[i] = 0;
cnt = tot = 0;
for (int i = 1, u, v; i < n; i ++ ){
scanf ("%d%d", &u, &v);
add(u, v);
add(v, u);
}
for (int i = 1; i <= n; i ++ ) scanf ("%d", &w[i]);
d[1] = 1, fa[1] = 1; // 总是忘记!!!
dfs_son(1, -1); //
dfs_seq(1, 1);
// for (int i = 1; i <= n; i ++ ) cout << seq[i] << endl;
build();
int m;
//cout << query(1,n) <<endl;
scanf ("%d", &m);
printf ("Case #%d:\n", tol ++);
while (m -- ){
int opt, x, y, num;
scanf ("%d%d%d", &opt, &x, &y);
if (opt == 4){
printf ("%d\n", query(x,y));
}
else {
scanf ("%d", &num);
if (opt == 1){
modify(x, y, 0, num);
}
else if (opt == 2){
modify(x, y, 1, num);
}
else modify(x, y, num, 0);
}
// 建议不要用switch 我测了一下 运行速度慢差不多300ms
/* switch(opt){
case 1:{
scanf ("%d", &num);
modify(x, y, 0, num);
break;
}
case 2:{
scanf ("%d", &num);
modify(x, y, 1, num);
break;
}
case 3:{
scanf ("%d", &num);
modify(x, y, num, 0);
break;
}
case 4:{
printf ("%d\n", query(x,y));
break;
}
}
*/
}
}
return 0;
}
/*
1
5
2 1
1 3
5 3
4 3
1 2 3 4 5
6
4 2 4
1 5 4 2
2 2 4 3
3 2 3 4
4 5 4
4 2 4
Case #1:
100
8133
20221
*/