【数据结构 树 树链剖分】luogu_3384 树链剖分

题意

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作 1 1 1: 格式: 1 1 1 x x x y y y z z z 表示将树从 x x x y y y结点最短路径上所有节点的值都加上 z z z

操作 2 2 2: 格式: 2 2 2 x x x y y y 表示求树从 x x x y y y结点最短路径上所有节点的值之和

操作 3 3 3: 格式: 3 3 3 x x x z z z 表示将以 x x x为根节点的子树内所有节点值都加上 z z z

操作 4 4 4: 格式: 4 4 4 x x x 表示求以 x x x为根节点的子树内所有节点值之和

思路

树链剖分,就是把一棵树剖成若干链,然后在树上的操作就转变成在序列里的操作,如何剖是关键。

定义如下数组:
f a t h e r father father,记录节点的me父亲
d e p dep dep,记录节点的深度
s i z e size size,记录节点的子树大小
s o n son son,记录节点的重儿子,就是该节点 s i z e size size最大的儿子
r e v rev rev,记录这个 d f s dfs dfs序对应的原来的节点
t o p top top,记录当前点处在的重链的顶端
s e g seg seg,记录当前点的 d f s dfs dfs

2 2 2 d f s dfs dfs即可求出,计算 s e g seg seg时,同一条重链上的点需要按顺序排在连续的一段区间。
这样子把树拆成若干条链,用线段树维护其。

树链剖分的两个性质:

1 , 1, 1如果 ( u , v ) (u,v) (u,v)是一条轻边,那么 s i z e ( v ) ≤ s i z e ( u ) / 2 size(v)\leq size(u)/2 size(v)size(u)/2;

2 , 2, 2从根结点到任意结点的路所经过的轻重链的个数必定都小于 l o g n logn logn;

因此树链剖分的总复杂度是 O ( l o g 2 n ) O(log^2n) O(log2n)的。

对于本题的 4 4 4种操作:
1 、 1、 1若两点不在同一条重链上,将 t o p top top较深的点往 f a t h e r [ t o p ] father[top] father[top]上跳,并更新跳过的区间。直到跳到一条重链上,这样就可以直接更新。
2 、 2、 2 1 1 1,不过将修改操作换成查询操作。
3 、 3、 3同一个子树的 d f s dfs dfs序一定是连续的,所以查询 x ∼ x + s i z e x − 1 x\sim x+size_x-1 xx+sizex1间的答案。
4 、 4、 4 3 3 3,查询操作变成修改操作。

其实操作 1 1 1的跳就相当于求 L C A LCA LCA,只不过这里直接跳完一个链。线段树就基本操作吧。

代码

#include <cctype>
#include <cstdio>
#include <algorithm>
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)

char buf[1 << 21], *p1 = buf, *p2 = buf;

struct segmentTree {
	int ls, rs, l, r, sum, lazy;
}t[400001];
int father[100001], dep[100001], size[100001], son[100001], rev[100001], top[100001], seg[100001];
int head[100001], ver[200001], next[200001];
int a[100001];
int n, m, r, mod, tot, cnt;

inline long long read() {
    long long res = 0, f = 0;
	char ch = getchar();
    while (!isdigit(ch)) {
        if (ch == '-') f = 1;
        ch = getchar();
	}
    while (isdigit(ch)) res = res * 10 + ch - 48, ch = getchar();
    return f ? -res : res;
}

void add(int u, int v) {
	ver[++tot] = v;
	next[tot] = head[u];
	head[u] = tot;
}

void dfs1(int u, int fa) {
	father[u] = fa;
	size[u] = 1;
	dep[u] = dep[fa] + 1;
	for (int i = head[u]; i; i = next[i]) {
		if (ver[i] == fa) continue;
		dfs1(ver[i], u);
		size[u] += size[ver[i]];
		if (size[ver[i]] > size[son[u]]) son[u] = ver[i];
	}
}

void dfs2(int u, int t) {
	top[u] = t;
	seg[u] = ++cnt;
	rev[cnt] = u;
	if (!son[u]) return;
	dfs2(son[u], t);
	for (int i = head[u]; i; i = next[i]) {
		if (top[ver[i]]) continue;
		dfs2(ver[i], ver[i]);
	}
}

void build(int p, int l, int r) {
	t[p].l = l, t[p].r = r;
	if (l == r) {
		t[p].sum = a[rev[l]] % mod;
		return;
	}
	int mid = l + r >> 1;
	build(t[p].ls = ++cnt, l, mid);
	build(t[p].rs = ++cnt, mid + 1, r);
	t[p].sum = (t[t[p].ls].sum + t[t[p].rs].sum) % mod;
}

inline void spread(int p) {
	if (!t[p].lazy) return;
	t[t[p].ls].lazy = (t[t[p].ls].lazy + t[p].lazy) % mod;
	t[t[p].rs].lazy = (t[t[p].rs].lazy + t[p].lazy) % mod;
	t[t[p].ls].sum = (t[t[p].ls].sum + (t[t[p].ls].r - t[t[p].ls].l + 1) * t[p].lazy) % mod;
	t[t[p].rs].sum = (t[t[p].rs].sum + (t[t[p].rs].r - t[t[p].rs].l + 1) * t[p].lazy) % mod;
	t[p].lazy = 0;
}

int query(int p, int l, int r) {
	if (l <= t[p].l && t[p].r <= r)
		return t[p].sum;
	spread(p);
	int mid = t[p].l + t[p].r >> 1, res = 0;
	if (l <= mid) res = (res + query(t[p].ls, l, r)) % mod;
	if (r > mid) res = (res + query(t[p].rs, l, r)) % mod;
	return res;
}

inline int ask(int x, int y) {
	int fx = top[x], fy = top[y], res = 0;
	while (fx != fy) {
		if (dep[fx] < dep[fy]) std::swap(x, y), std::swap(fx, fy);
		res = (res + query(1, seg[fx], seg[x])) % mod;
		x = father[fx];
		fx = top[x];
	}
	if (dep[x] > dep[y]) std::swap(x, y);
	res = (res + query(1, seg[x], seg[y])) % mod;
	return res;
}

void update(int p, int l, int r, int val) {
	if (l <= t[p].l && t[p].r <= r) {
		t[p].sum = (t[p].sum + (t[p].r - t[p].l + 1) * val) % mod;
		t[p].lazy = (t[p].lazy + val) % mod;
		return;
	}
	spread(p);
	int mid = t[p].l + t[p].r >> 1;
	if (l <= mid) update(t[p].ls, l, r, val);
	if (r > mid) update(t[p].rs, l, r, val);
	t[p].sum = (t[t[p].ls].sum + t[t[p].rs].sum) % mod; 
}

void modify(int x, int y, int val) {
	int fx = top[x], fy = top[y];
	while (fx != fy) {
		if (dep[fx] < dep[fy]) std::swap(x, y), std::swap(fx, fy);
		update(1, seg[fx], seg[x], val);
		x = father[fx];
		fx = top[x];
	}
	if (dep[x] > dep[y]) std::swap(x, y);
	update(1, seg[x], seg[y], val);
}

signed main() {
	n = read(), m = read(), r = read(), mod = read();
	for (register int i = 1; i <= n; i++)
		a[i] = read();
	for (register int i = 1; i < n; i++) {
		int x = read(), y = read();
		add(x, y), add(y, x);
	}
	dfs1(r, 0);
	dfs2(r, r);
	build(cnt = 1, 1, n);
	for (int op, x, y, z; m; m--) {
		op = read();
		if (op == 1) {
			x = read(), y = read(), z = read() % mod;
			modify(x, y, z);
		} else if (op == 2) {
			x = read(), y = read();
			printf("%d\n", ask(x, y));
		} else if (op == 3) {
			x = read(), z = read() % mod;
			update(1, seg[x], seg[x] + size[x] - 1, z);
		} else {
			x = read();
			printf("%d\n", query(1, seg[x], seg[x] + size[x] - 1));
		}
	}
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值