树链剖分
树链剖分用于将树分割成若干条链的形式,以维护树上路径的信息。
定义
重子节点 :其子节点中子树最大的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。
轻子节点: 表示剩余的所有子结点。
从这个结点到重子节点的边为 重边。
到其他轻子节点的边为 轻边。
若干条首尾衔接的重边构成 重链。
初始化
两遍dfs
第一次处理每个节点的dep(深度) , fa(父亲节点) ,sz(子树大小), son(重链结点)
void dfs1(int u , int v , int de ){
dep[u] = de ; fa[u] = v ; sz[u] = 1 ;
for(auto t : ve[u] ){
if(t == v ) continue ;
dfs1(t , u , de + 1 ) ;
sz[u] += sz[t] ;
if(sz[son[u]] < sz[t]) son[u] = t ;
}
}
第二遍处理 id(dfs序) ,nw(dfs序中对应的点权值) , top(重链的头)
void dfs2(int u ,int v, int f ) {
id[u] = ++ cnt , nw[cnt] = w[u] , top[u] = f ;
if(!son[u]) return ;
dfs2(son[u] ,u, f ) ;
for(auto t : ve[u]) {
if(t == v ) continue ;
if(t == son[u]) continue ;
dfs2(t , u , t ) ;
}
}
线段树操作
push_up
void push_up(int u ){
tr[u].sum = tr[u << 1].sum + tr[u<< 1 | 1].sum ;
}
push_down
void push_down(int u ) {
if(tr[u].flag != 0 ) {
ll &f = tr[u].flag;
node &l = tr[u << 1] , &r = tr[u << 1 | 1 ] ;
l.flag += f , r.flag += f;
l.sum += f * (l.r - l.l + 1 ) , r.sum += f*(r.r - r.l + 1 ) ;
f = 0 ;
}
}
build
void build(int u , int l ,int r) {
tr[u] = {l , r } ;
if(l == r ) {
tr[u] = {l , r , nw[l] , 0 } ;
return ;
}
int mid = l + r >> 1 ;
build(u << 1 , l , mid) , build(u << 1 | 1 , mid + 1 , r ) ;
push_up(u);
}
modify
void modify(int u , int l , int r , int k ) {
if(tr[u].l >= l && tr[u].r <= r ) {
tr[u].flag += k ;
tr[u].sum += k * (tr[u].r - tr[u].l + 1 ) ;
return ;
}
push_down(u) ;
int mid = tr[u].l + tr[u].r >> 1 ;
if(l <= mid ) modify(u << 1 , l , r , k ) ;
if(mid < r ) modify(u << 1 | 1, l ,r , k ) ;
push_up(u) ;
}
query
ll query(int u , int l , int r ) {
if(tr[u].l >= l && tr[u].r <= r ) return tr[u].sum ;
push_down(u) ;
int mid = tr[u].l + tr[u].r >> 1 ;
ll sum = 0 ;
if(l <= mid ) sum += query(u << 1 , l , r ) ;
if(mid < r ) sum += query(u << 1 | 1 , l , r ) ;
return sum ;
}
线段树和树结合
修改两点之间的路径全部点的权值
···c++
void modify_path(int u , int v , int k ){
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u , v ) ;
modify(1 , id[top[u]] , id[u] , k ) ;
u = fa[top[u]] ;
}
if(dep[u] < dep[v] ) swap(u , v) ;
modify(1 ,id[v], id[u], k) ;
}
···
查询两点之间路径的权值和
ll query_path(int u , int v ){
ll sum = 0 ;
while(top[u] != top[v]) {
if(dep[top[u]] < dep[top[v]]) swap(u , v) ;
sum += query(1 , id[top[u]] , id[u]) ;
u = fa[top[u]] ;
}
if(dep[u] < dep[v]) swap(u , v );
sum += query(1 , id[v] , id[u] ) ;
return sum ;
}
修改子树所有点的权值
void modify_tree(int u , int k ) {
modify(1 , id[u] , id[u] + sz[u] - 1 , k ) ;
}
查询子树的权值和
ll query_tree(int u) {
return query(1 , id[u] , id[u] + sz[u] - 1 ) ;
}
题
链接
代码
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e5+10;
vector<int> ve[N];
int w[N];
int id[N] , nw[N] , cnt;
//id dfn序 , nw[id[i]]权值为i的权值w
int dep[N] , sz[N] , top[N] , fa[N] ,son[N];
//sz子树节点个数, top:重链的顶点, son:重儿子, fa:父节点
struct node{
int l , r;
int sum , flag;
}tr[N * 4 ];
void dfs1(int u , int v , int de){
dep[u] = de ; fa[u] = v ; sz[u] = 1;
for(auto t : ve[u]){
if(t == v) continue;
dfs1(t , u , de + 1);
sz[u] += sz[t];
if(sz[son[u]] < sz[t]) son[u] = t;
}
}
void dfs2(int u ,int v, int f){
id[u] = ++ cnt , nw[cnt] = w[u] , top[u] = f;
if(!son[u]) return;
dfs2(son[u] , u , f);
for(auto t : ve[u]){
if(t == v) continue;
if(t == son[u]) continue;
dfs2(t , u , t);
}
}
void push_up(int u){
tr[u].sum = tr[u << 1].sum + tr[u<< 1 | 1].sum;
}
void build(int u , int l ,int r){
tr[u] = {l , r };
if(l == r ){
tr[u] = {l , r , nw[l] , 0};
return;
}
int mid = l + r >> 1;
build(u << 1 , l , mid) , build(u << 1 | 1 , mid + 1 , r);
push_up(u);
}
void push_down(int u){
if(tr[u].flag != 0){
int &f = tr[u].flag;
node &l = tr[u << 1] , &r = tr[u << 1 | 1 ];
l.flag += f , r.flag += f;
l.sum += f * (l.r - l.l + 1 ) , r.sum += f*(r.r - r.l + 1 );
f = 0;
}
}
void modify(int u , int l , int r , int k){
if(tr[u].l >= l && tr[u].r <= r) {
tr[u].flag += k;
tr[u].sum += k * (tr[u].r - tr[u].l + 1 );
return;
}
push_down(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modify(u << 1 , l , r , k );
if(mid < r) modify(u << 1 | 1, l ,r , k);
push_up(u);
}
int query(int u , int l , int r){
if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
push_down(u);
int mid = tr[u].l + tr[u].r >> 1;
int sum = 0;
if(l <= mid ) sum += query(u << 1 , l , r);
if(mid < r ) sum += query(u << 1 | 1 , l , r);
return sum;
}
void modify_path(int u , int v , int k){
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u , v);
modify(1 , id[top[u]] , id[u] , k);
u = fa[top[u]];
}
if(dep[u] < dep[v]) swap(u , v);
modify(1 , id[v] , id[u] , k);
}
void modify_tree(int u , int k) {
modify(1 , id[u] , id[u] + sz[u] - 1 , k);
}
int query_tree(int u) {
return query(1 , id[u] , id[u] + sz[u] - 1);
}
int query_path(int u , int v){
int sum = 0;
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u , v);
sum += query(1 , id[top[u]] , id[u]);
u = fa[top[u]];
}
if(dep[u] < dep[v]) swap(u , v);
sum += query(1 , id[v] , id[u]);
return sum;
}
signed main(){
int n;
cin>>n;
for(int i = 1 ; i <= n; i ++) cin>>w[i];
for(int i = 1 ; i < n ; i ++) {
int a , b;
cin>>a>>b;
ve[a].push_back(b);
ve[b].push_back(a);
}
dfs1(1 , -1 , 1);
dfs2(1 , -1 , 1);
build(1 , 1 , n);
int q;
cin>>q;
while(q --){
int op;
cin>>op;
if(op == 1 ){
int u , v , k;
cin>>u>>v>>k;
modify_path(u , v , k);
}else if(op == 2){
int u , k;
cin>>u>>k;
modify_tree(u , k);
}else if(op == 3 ){
int u , v;
cin>>u>>v;
cout<<query_path(u , v)<<"\n";
}else{
int u;
cin>>u;
cout<<query_tree(u)<<"\n";
}
}
return 0 ;
}
(个人学习)