P3384 【模板】轻重链剖分
题意
已知一棵包含 N 个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作 1: 格式: 1 x y z 表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z。
操作 2: 格式: 2 x y 表示求树从 x 到 y 结点最短路径上所有节点的值之和。
操作 3: 格式: 3 x z 表示将以 x 为根节点的子树内所有节点值都加上 z。
操作 4: 格式: 4 x 表示求以 x 为根节点的子树内所有节点值之和。
思路
先将树轻重链剖分后,遍历整棵树的时候优先遍历重链,对于每一条链来说,链上节点对应的dfs序是连续的,所以我们可以很容易得利用一些数据结构对链进行操作。
先通过两次dfs得到下面的数组:
int f[maxN]; //父亲结点
int d[maxN]; //结点深度
int siz[maxN]; //子树大小
int son[maxN]; //重儿子标号
int dfn[maxN]; //dfs序
int rk[maxN]; //dfs序对应的结点
int top[maxN]; //结点所在链的顶端结点
对于操作1 2:我们每次找两个节点间的路径的时候,可以一直跳top[]一直跳到两个节点在同一条链上为止。这中间跳过的每一条链都需要进行加或者查询操作。这就涉及线段树区间加操作和区间和查询了。
对于操作3 4:我们事先已经统计了每个子树的节点个数(子树大小),而dfs序的特点就是一颗子树上节点标号是连续的,所以我们可以直接对区间 [ d f s [ x ] , d f s [ x ] + s i z [ x ] − 1 ] [dfs[x], dfs[x] + siz[x] - 1] [dfs[x],dfs[x]+siz[x]−1]进行操作。
Code
#include <bits/stdc++.h>
#define MID (l + r) >> 1
#define lsn rt << 1
#define rsn rt << 1 | 1
#define Lson lsn, l, mid
#define Rson rsn, mid + 1, r
#define QL Lson, ql, qr
#define QR Rson, ql, qr
using namespace std;
typedef long long ll ;
int read() {
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9') { if(ch == '-') f = -f; ch = getchar(); }
while(ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
return x * f;
}
const int maxN = 1e5 + 10;
struct EDGE{
int adj, to;
}edge[maxN << 1];
int head[maxN], cnt, tot;
void init() {
memset(head, -1, sizeof(head));
cnt = tot = 0;
}
void add_edge(int u, int v) {
edge[cnt] = EDGE{head[u], v};
head[u] = cnt ++ ;
}
/*-------------------------------------------------*/
int f[maxN]; //父亲结点
int d[maxN]; //结点深度
int siz[maxN]; //子树大小
int son[maxN]; //重儿子标号
/*dfs1(): 得到上面四个数组*/
void dfs1(int u, int fa, int depth) {
f[u] = fa;
d[u] = depth;
siz[u] = 1;
for(int i = head[u]; ~i; i = edge[i].adj) {
int v = edge[i].to;
if(v == fa) continue;
dfs1(v, u, depth + 1);
siz[u] += siz[v];
if(siz[son[u]] < siz[v]) {
son[u] = v;
}
}
}
/*-------------------------------------------------*/
int dfn[maxN]; //dfs序
int rk[maxN]; //dfs序对应的结点
int top[maxN]; //结点所在链的顶端结点
/*dfs2: 得到上面三个数组*/
void dfs2(int u, int tp) {
dfn[u] = ++tot;
rk[tot] = u;
top[u] = tp;
if (!son[u]) return;
dfs2(son[u], tp); //优先选择重儿子来保证一条重链上各个结点的dfs序连续,某结点和其重儿子一定在同一条中脸上,所以顶端不变
for (int i = head[u]; ~i; i = edge[i].adj) {
int v = edge[i].to;
if (v != son[u] && v != f[u]) dfs2(v, v); //v位于轻链顶端,那这条链的顶端必然是其本身
}
}
/*-------------------------------------------------*/
//线段树部分
int n, m, root;
ll p, a[maxN];
ll sum[maxN << 2], add[maxN << 2]; //区间最大值,区间和
void pushup(int rt) {
sum[rt] = sum[lsn] + sum[rsn];
sum[rt] %= p;
}
void pushdown(int rt, int l, int r) {
if(add[rt]) {
int mid = MID;
sum[lsn] += add[rt] * (ll)(mid - l + 1); sum[lsn] %= p;
sum[rsn] += add[rt] * (ll)(r - mid); sum[rsn] %= p;
add[lsn] += add[rt]; add[lsn] %= p;
add[rsn] += add[rt]; add[rsn] %= p;
add[rt] = 0;
}
}
void build(int rt, int l, int r) {
if(l == r) { sum[rt] = a[rk[l]] % p; return ;}
int mid = MID;
build(Lson);
build(Rson);
pushup(rt);
}
void update(int rt, int l, int r, int ql, int qr, int val) {
if(ql <= l && qr >=r ) {
sum[rt] += val * (r - l + 1); sum[rt] %= p;
add[rt] += val; add[rt] %= p;
return ;
}
pushdown(rt, l, r);
int mid = MID;
if(qr <= mid) update(QL, val);
else if(ql > mid) update(QR, val);
else update(QL, val), update(QR, val);
pushup(rt);
}
//查询区间和
ll query(int rt, int l, int r, int ql, int qr) {
if(ql <= l && qr >= r) return sum[rt] % p;
int mid = MID;
pushdown(rt, l, r);
if(qr <= mid) return query(QL) % p;
else if(ql > mid) return query(QR) % p;
else return (query(QL) + query(QR)) % p;
}
/*-------------------------------------------------*/
//求LCA
int getLCA(int x, int y) {
while(top[x] != top[y]) { //两个点不在同一条链上
if(d[top[x]] < d[top[y]]) swap(x, y); //使x表示所在链深度更深的点
x = f[top[x]]; //将x更新为所在链顶端结点的父亲节点
}
//两个点在同一条链上,深度更小的点为LCA
return d[x] < d[y] ? x : y;
}
/*-------------------------------------------------*/
/*求结点x到结点y的最短路径
* :一直加重链的贡献,直到xy在同一条链上,则加上两点间的贡献即可。
*/
int getDis(int x, int y) {
ll ans = 0;
while(top[x] != top[y]) {
if(d[top[x]] < d[top[y]]) swap(x, y); //使x表示所在链深度更深的点
ans += query(1, 1, n, dfn[top[x]], dfn[x]); //加上x结点这条重链上的贡献
ans %= p;
x = f[top[x]];
}
if(d[x] > d[y]) swap(x, y); //使x表示所在链深度更浅的点,深度更深标号越靠后
ans += query(1, 1, n, dfn[x], dfn[y]);
ans %= p;
return (int)ans;
}
/*-------------------------------------------------*/
/*修改结点x到结点y最短路径上的点权*/
void UPDATE(int x, int y, int val) {
while(top[x] != top[y]) {
if(d[top[x]] < d[top[y]]) swap(x, y);
update(1, 1, n, dfn[top[x]], dfn[x], val);
x = f[top[x]];
}
if(d[x] > d[y]) swap(x, y);
update(1, 1, n, dfn[x], dfn[y], val);
}
int main() {
// freopen("in.in", "r", stdin);
init();
n = read(); m = read(); root = read();
p = (ll)read();
for(int i = 1; i <= n; ++ i ) {
a[i] = (ll)read();
}
for(int i = 0; i < n - 1; ++ i ) {
int u = read(), v = read();
add_edge(u, v), add_edge(v, u);
}
dfs1(root, 0, 1);
dfs2(root, root);
build(1, 1, n);
while(m -- ) {
int op, x, y;
ll z;
op = read();
if(op == 1) { //区间修改
x = read(), y = read(); z = (ll)read();
UPDATE(x, y, z);
} else if(op == 2){ //区间查询点权和
x = read(), y = read();
printf("%d\n", getDis(x, y));
// fflush(stdout);
} else if(op == 3){ //子树修改
x = read(), z = (ll) read();
update(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, z);
} else { //子树查询
x = read();
printf("%d\n", (int)query(1, 1, n, dfn[x], dfn[x] + siz[x] - 1));
// fflush(stdout);
}
}
return 0;
}
哦,真的会有人卡线段树区间加吗,是的真的有人
哦,真的有人会边不开两倍,线段树不开四倍吗,是的真的有人
GG