树链剖分:
即把树拆成若干条不相交的链, 分为三种:
一.重链剖分—常用, O(logn)
二.长链剖分—不常用,O(sqrt(n))
三.实链剖分—搞LCT(现在还不懂它是什么)
术语:
1.重儿子:一个节点的所有儿子中,大小最大的那个(最重的,所以说只有一个,
如果有多个儿子大相等那就随便取一个)
2.轻儿子:一个节点除了轻儿子以外的儿子都是轻儿子
3.重链:从轻儿子开始(根结点也是轻儿子),一路往重儿子走,出的链叫做重链
4.轻链:除了重链全是轻链
注意: 这里用到dfs序, dfs标时间戳的时候优先往重儿子走
所以需要两次dfs 第一次dfs先得标出重儿子这样才能进行上面所说的dfs序。
拿个题做做:AcWing1278
点击这里
问题就是说给你一颗树,树上每个节点都有一个权值,让你完成三个操作:
1.单点修改(线段树的基础操作);
2.询问两节点间的权值最大和(这就需要把树拆分成链在建线段树,对于查询操作会与线段树有所不同, 具体在后面会写);
3.询问两节点间的权值和(和操作二一样)。
操作都包括节点。
需要的有:重儿子:son[N]; 节点父 fa[N];
节点深度dep[N]; 儿子大小num[N].
首先看一下第一遍dfs:
const int N = 1e5 + 7;
const int INF = 1e9 + 7;
int son[N], fa[N];
int dep[N], num[N];
vector<int> G[N];
void dfs1(int s, int f)
{
num[s] = 1;
for (int i = 0; i < G[s].size(); i ++ )
{
int v = G[s][i];
if (v == f) continue; // 如果是父亲节点直接跳过
fa[v] = s;
dep[v] = dep[s] + 1; // 儿子节点 深度加1
dfs1(v,s);
num[s] += num[v]; // 当前节点加上儿子节点大小(这里的大小是说儿子的个数)
if (num[v] > num[son[s]]) son[s] = v; // 更新重儿子
}
}
标记出重儿子 就可以进行dfs序了:
int pre[N], fpre[N]; // pre 记录dfs序 fpre记录dfs序下的节点
int top[N]; // 重链顶端
int cnt;// 时间戳
void dfs2(int s, int tp) // dfs序 tp 为重链的顶端
{
pre[s] = ++ cnt;
fpre[cnt] = s;
top[s] = tp; // 重链顶
if (son[s]) dfs2(son[s], tp); // 有重儿子先走重儿子
for (int i = 0; i < G[s].size(); i ++ )
{
int v = G[s][i];
if(v == fa[s] || v == son[s]) continue;
dfs2(v, v); // 轻链的重链顶是自己
}
}
接下来就是建树操作了:
struct Node{
int l, r, maxv, num;
int lazy;
}T[N<<2];
void up(int k)
{
T[k].maxv = max(T[k << 1].maxv, T[k << 1 | 1].maxv);
T[k].num = T[k << 1].num + T[k << 1 | 1].num;
}
void insert(int l, int r, int k = 1)
{
T[k].l = l;
T[k].r = r;
T[k].lazy = 0;
if (l == r)
{
T[k].num = T[k].maxv = arr[fpre[l]]; // 是不是这个arr数组有点懵了 想想fpre记录的是什么
return ;
}
int mid = l + r >> 1;
insert(l, mid, k << 1);
insert(mid + 1, r, k << 1 | 1);
up(k);
}
修改操作 :与线段树一样:
void modify(int pos, int val, int k = 1)
{
if (T[k].l == T[k].r)
{
T[k].maxv = T[k].num = val;
return ;
}
int mid = T[k].l + T[k].r >> 1;
if (pos <= mid) modify(pos, val, k << 1);
else modify(pos, val, k << 1 | 1);
up(k);
}
不一样的查询操作:
int query(int l, int r, int k = 1)
{
if (l <= T[k].l && r >= T[k].r) return T[k].num;
int ans = 0;
int mid = T[k].l + T[k].r >> 1;
if (l <= mid) ans += query(l, r, k << 1);
if (r > mid) ans += query(l, r, k << 1 | 1);
return ans;
}
int querySum(int L, int R) // 我在是不一样的
{
int ans = 0;
while (top[L] != top[R]) // 因为我们要查询的树上是两节点的值,所以查询路径就会有所改变, 拿上面的图举例
//比如要查询6到8我们先查询8到4,在查询6到1。
//因为 我们的线段树是通过dfs序建立的 所以需要查询的是每条链1到6为一个链 8到4为一个链
{
if (dep[top[L]] < dep[top[R]]) swap(L, R); // 这里是为了方便 所以先查询节点深度大的
ans += query(pre[top[L]], pre[L]);
L = fa[top[L]];
}
if (dep[L] < dep[R]) swap(L, R);
ans += query(pre[R], pre[L]);
return ans;
}
最大值操作与查询一样。
剩下的也没什么了 完整代码了解一下:
void dfs1(int s, int f)
{
num[s] = 1;
for (int i = 0; i < G[s].size(); i ++ )
{
int v = G[s][i];
if (v == f) continue;
fa[v] = s;
dep[v] = dep[s] + 1;
dfs1(v,s);
num[s] += num[v];
if (num[v] > num[son[s]]) son[s] = v; // 更新重儿子
}
}
int pre[N], fpre[N];
int top[N]; //
int cnt;
void dfs2(int s, int tp) // dfs序
{
pre[s] = ++ cnt;
fpre[cnt] = s;
top[s] = tp; // 重链顶
if (son[s]) dfs2(son[s], tp); // 有重儿子先走重儿子
for (int i = 0; i < G[s].size(); i ++ )
{
int v = G[s][i];
if(v == fa[s] || v == son[s]) continue;
dfs2(v, v); // 轻链的重链顶是自己
}
}
struct Node{
int l, r, maxv, num;
int lazy;
}T[N<<2];
void up(int k)
{
T[k].maxv = max(T[k << 1].maxv, T[k << 1 | 1].maxv);
T[k].num = T[k << 1].num + T[k << 1 | 1].num;
}
void insert(int l, int r, int k = 1)
{
T[k].l = l;
T[k].r = r;
T[k].lazy = 0;
if (l == r)
{
T[k].num = T[k].maxv = arr[fpre[l]];
return ;
}
int mid = l + r >> 1;
insert(l, mid, k << 1);
insert(mid + 1, r, k << 1 | 1);
up(k);
}
void modify(int pos, int val, int k = 1)
{
if (T[k].l == T[k].r)
{
T[k].maxv = T[k].num = val;
return ;
}
int mid = T[k].l + T[k].r >> 1;
if (pos <= mid) modify(pos, val, k << 1);
else modify(pos, val, k << 1 | 1);
up(k);
}
int query(int l, int r, int k = 1)
{
if (l <= T[k].l && r >= T[k].r)
{
return T[k].num;
}
int ans = 0;
int mid = T[k].l + T[k].r >> 1;
if (l <= mid) ans += query(l, r, k << 1);
if (r > mid) ans += query(l, r, k << 1 | 1);
return ans;
}
int querySum(int L, int R)
{
int ans = 0;
while (top[L] != top[R])
{
if (dep[top[L]] < dep[top[R]]) swap(L, R);
ans += query(pre[top[L]], pre[L]);
L = fa[top[L]];
}
if (dep[L] < dep[R]) swap(L, R);
ans += query(pre[R], pre[L]);
return ans;
}
int query1(int l, int r, int k = 1)
{
if (l <= T[k].l && r >= T[k].r)
{
return T[k].maxv;
}
int ans = -INF, ans2 = -INF;
int mid =T[k].l + T[k].r >> 1;
if (l <= mid) ans = max(ans, query1(l, r, k << 1));
if (r > mid) ans = max(ans, query1(l, r, k << 1 | 1));
return ans;
}
int queryMax(int L, int R)
{
int ans = -INF;
while (top[L] != top[R])
{
if (dep[top[L]] < dep[top[R]]) swap(L, R);
ans = max(ans, query1(pre[top[L]], pre[L]));
L = fa[top[L]];
}
if (dep[L] < dep[R]) swap(L, R);
ans = max(ans, query1(pre[R], pre[L]));
return ans;
}
int main ()
{
memset(arr, 0, sizeof arr);
int n, m, a, b;
scanf ("%d", &n);
for (int i = 0; i <= n; i ++ ) G[i].clear();
for (int i = 1; i < n; i ++ )
{
scanf ("%d%d", &a, &b);
G[a].push_back(b);
G[b].push_back(a);
}
for (int i = 1; i <= n; i ++ ) scanf ("%d", &arr[i]);
dep[1] = 1, fa[1] = 1;
dfs1(1, -1);
dfs2(1, 1);
insert(1, n);
scanf ("%d", &m);
for (int i = 0; i < m; i ++ )
{
char str[20];
scanf ("%s%d%d", str, &a, &b);
if (str[1] == 'M') printf ("%d\n", queryMax(a, b));
else if (str[1] == 'S') printf ("%d\n", querySum(a, b));
else modify(pre[a], b); // 这个地方我找了 好久好久 (自闭ing)
}
return 0;
}
Fighting