改编自:
https://blog.csdn.net/sdut16szq/article/details/79148096
http://www.cnblogs.com/KingSann/articles/9441763.html
算法:
LCA,DFS序,线段树区间维护
解题思路:
首先考虑不真正修改树的根,而是通过分析当前根与询问节点的相对位置关系,借助求LCA实现各个操作。由于每次操作修改的点不确定,最坏情况下暴力修改的总复杂度会达到$O(n^2)$,所以自然地考虑用线段树维护。
然后我们先做好LCA算法的准备工作,以及通过DFS序建立区间更新的线段树。
接着分析每一个操作:
操作1:直接将root标记成x即可。
操作2:仔细分析当前根与u,v之间可能存在的位置关系,共有三种情况,分别分析即可得出求解方法。
仔细想一下,什么时候换了根的树的答案和原树不同?
对于修改操作:
如果lca(u,v)(在以1为根的原树中)是当前root的祖先,那么按原树计算子树时会同时出现少算和多算的部分。
其他情况查询答案和原树相同
对于查询操作,少了一个参数,一样的道理
一种情况是查询点x是当前根r的祖先(包括x=r),此时答案和原树不同;
否则答案与原树相同。
几个技巧:
a是b的祖先 ⟺ lca(a,b) = a
dfs序:以点x为根的子树中所有节点的dfs序是连续的,即区间[first[x], last[x]],first和last数组在dfs中求得。
第一版代码(当时CF无法提交,稍后测试,果然没有通过):
#include<bits/stdc++.h>
// Subtree sums (value * subtree size) overflow int, so accumulate in 64 bits.
using ll = long long;
using namespace std;
// Debug printf that expands to exactly one statement. The do/while(0) wrapper
// keeps a following `else` from silently binding to the hidden `if (debug)`,
// which the object-like `#define dprintf if (debug) printf` form allowed.
#define dprintf(...) do { if (debug) printf(__VA_ARGS__); } while (0)
// Loop i over the half-open range [j, k).
#define rep(i, j, k) for (int i = (j); i < (k); i++)
const int maxn = 1e5+50;  // capacity of node-indexed arrays (nodes are 1..n)
const int debug = 0;      // set to 1 to enable dprintf output

// Adjacency list: edge e points to edges[e]; pre[e] is the previously added
// edge from the same source; tail[v] is the latest edge out of v (0 = none).
int edges[maxn*2], pre[maxn*2], tail[maxn], cnt;
// Binary lifting: fa[v][j] is v's 2^j-th ancestor (node 0 is the sentinel
// above the root); depth[] is the level assigned by the DFS.
int fa[maxn][24], depth[maxn];
// DFS order: v's subtree occupies positions [first[v], last[v]]; f[pos] is
// the node stored at position pos; step is the running position counter.
int first[maxn], last[maxn], f[maxn], step;
// Segment-tree arguments passed through globals: range [ql, qr], delta x.
int ql, qr, x;
// Input and per-query scratch values.
int n, q, a[maxn], fr, to, op, root, u, v;
int luv, luvr, lrv, lur, lvr, undo;
ll pls, mi;
// Segment tree over DFS positions 1..n: node sums and pending add tags.
// 4 * maxn is the standard safe size for a recursive 1-based segment tree.
ll sum[maxn*4], lazy[maxn*4];
// Append the directed edge fr -> to. The new edge gets id cnt + 1 and
// becomes the head of fr's edge chain; pre[] links it to the former head.
void addEdge(int fr, int to){
    const int id = ++cnt;
    edges[id] = to;
    pre[id] = tail[fr];   // previous most-recent edge out of fr
    tail[fr] = id;
}
// Push node o's pending additive tag (o covers [l, r]) to its two children,
// which cover [l, mid] and [mid + 1, r], then clear the tag on o.
void pushdown(int o, int l, int r){
    const int mid = (l + r) >> 1;
    const int left = o << 1;
    const int right = left | 1;
    const long long tag = lazy[o];
    lazy[left] += tag;
    lazy[right] += tag;
    sum[left] += tag * (mid - l + 1);
    sum[right] += tag * (r - mid);
    lazy[o] = 0;
}
// Add the global delta x to every position in the global range [ql, qr].
// o is the current segment-tree node and covers positions [l, r].
void update(int o, int l, int r){
    if (ql <= l && qr >= r){
        // Fully covered: update this node and defer the children via lazy[].
        // Widen before multiplying; x * (segment length) can exceed INT_MAX.
        sum[o] += static_cast<long long>(x) * (r - l + 1);
        lazy[o] += x;
        return;
    }
    pushdown(o, l, r);
    const int mid = (l + r) >> 1;
    if (ql <= mid) update(o << 1, l, mid);
    if (qr > mid) update(o << 1 | 1, mid + 1, r);
    sum[o] = sum[o << 1] + sum[o << 1 | 1];
}
// Return the sum of positions in the global range [ql, qr] that lie inside
// node o, which covers [l, r]. Pending tags on the way down are pushed.
ll query(int o, int l, int r){
    if (ql <= l && r <= qr) return sum[o];   // node fully inside the range
    pushdown(o, l, r);
    const int mid = (l + r) >> 1;
    ll total = 0;
    if (ql <= mid) total += query(o << 1, l, mid);
    if (mid < qr) total += query(o << 1 | 1, mid + 1, r);
    return total;
}
// Recursive DFS from u at depth `layer`. It records depth[u], the immediate
// parent fa[child][0] of each child, and the DFS-order interval
// [first[u], last[u]] that u's subtree occupies; f[pos] is the node at pos.
void dfs(int u, int layer){
    depth[u] = layer;
    step++;
    first[u] = step;
    f[step] = u;
    for (int e = tail[u]; e != 0; e = pre[e]){
        const int child = edges[e];
        if (child == fa[u][0]) continue;   // skip the edge back to the parent
        fa[child][0] = u;
        dfs(child, layer + 1);
    }
    last[u] = step;   // largest DFS position inside u's subtree
}
void makelca(){
rep(i, 1, n+1){
rep(j, 1, 23){
fa[i][j] = fa[fa[i][j-1]][j-1];
}
}
}
// Lowest common ancestor of x and y in the tree built by dfs()/makelca(),
// using the binary-lifting table fa[][] (levels 0..22) and depth[].
int lca(int x, int y){
    if (depth[x] < depth[y]) std::swap(x, y);   // x is now the deeper node
    // Lift x until it is at y's depth.
    for (int k = 22; k >= 0; k--)
        if (depth[fa[x][k]] >= depth[y]) x = fa[x][k];
    if (x == y) return x;
    // Lift both while their ancestors differ; they stop just below the LCA.
    for (int k = 22; k >= 0; k--){
        const int px = fa[x][k];
        const int py = fa[y][k];
        if (px != py){
            x = px;
            y = py;
        }
    }
    return fa[x][0];
}
// Lift `son` to its ancestor at depth depth[father] + 1. When `father` is a
// proper ancestor of `son`, this is the child of `father` on the path to
// `son`. If `son` is not deeper than `father`, it is returned unchanged.
int lst(int son, int father){
    const int target = depth[father];
    int k = 22;
    while (k >= 0){
        const int up = fa[son][k];
        if (depth[up] > target) son = up;
        --k;
    }
    return son;
}
int main(){
scanf("%d%d", &n, &q);
rep(i, 1, n+1){
scanf("%d", &a[i]);
}
rep(i, 0, n-1){
scanf("%d%d", &fr, &to);
addEdge(fr, to);
addEdge(to, fr);
}
dfs(1, 1);
rep(i, 1, n+1){
ql = first[i]; qr = first[i]; x = a[i];
update(1, 1, n);
}
dprintf("\n\n\n\n\n");
makelca();
rep(i, 0, q){
dprintf("\n\n\n\n\n");
scanf("%d", &op);
if (op == 1){
scanf("%d", &root);
}
else if (op == 2){
scanf("%d%d%d", &u, &v, &x);
luv = lca(u, v); luvr = lca(luv, root);
if (luv == luvr){
ql = 1; qr = n;
dprintf("update subtree of 1 x = %d\n", x);
update(1, 1, n);
if (depth[luv] < depth[root]){
lur = lca(u, root); lvr = lca(v, root);
if (depth[lur] > depth[lvr]) undo = lst(root, lur);
else undo = lst(root, lvr);
ql = first[undo]; qr = last[undo]; x = -x;
dprintf("update subtree of %d x = %d\n", undo, x);
update(1, 1, n);
}
}
else{
dprintf("update subtree of %d x = %d\n", luv, x);
ql = first[luv]; qr = last[luv];
update(1, 1, n);
}
}
else{
scanf("%d", &v);
lrv = lca(root, v);
if (v == root){
ql = 1; qr = n;
printf("%lld\n", query(1, 1, n));
}
else if (lrv != v){
ql = first[v]; qr = last[v];
printf("%lld\n", query(1, 1, n));
}
else {
ql = 1; qr = n;
pls = query(1, 1, n);
undo = lst(root, lrv);
ql = first[undo]; qr = last[undo];
mi = query(1, 1, n);
printf("%lld\n", pls - mi);
}
}
}
}
AC代码:
#include<bits/stdc++.h>
// Subtree sums (value * subtree size) overflow int, so accumulate in 64 bits.
using ll = long long;
using namespace std;
// Debug printf that expands to exactly one statement. The do/while(0) wrapper
// keeps a following `else` from silently binding to the hidden `if (debug)`,
// which the object-like `#define dprintf if (debug) printf` form allowed.
#define dprintf(...) do { if (debug) printf(__VA_ARGS__); } while (0)
// Loop i over the half-open range [j, k).
#define rep(i, j, k) for (int i = (j); i < (k); i++)
const int maxn = 1e5+50;  // capacity of node-indexed arrays (nodes are 1..n)
const int debug = 0;      // set to 1 to enable dprintf output

// Adjacency list: edge e points to edges[e]; pre[e] is the previously added
// edge from the same source; tail[v] is the latest edge out of v (0 = none).
int edges[maxn*2], pre[maxn*2], tail[maxn], cnt;
// Binary lifting: fa[v][j] is v's 2^j-th ancestor (node 0 is the sentinel
// above the root); depth[] is the level assigned by the DFS.
int fa[maxn][24], depth[maxn];
// DFS order: v's subtree occupies positions [first[v], last[v]]; f[pos] is
// the node stored at position pos; step is the running position counter.
int first[maxn], last[maxn], f[maxn], step;
// Segment-tree arguments passed through globals: range [ql, qr], delta x.
int ql, qr, x;
// Input and per-query scratch values.
int n, q, a[maxn], fr, to, op, root, u, v;
int luv, luvr, lrv, lur, lvr, undo;
ll pls, mi;
// Segment tree over DFS positions 1..n: node sums and pending add tags.
// 4 * maxn is the standard safe size for a recursive 1-based segment tree.
ll sum[maxn*4], lazy[maxn*4];
// Register the directed edge fr -> to at the front of fr's edge chain.
// Edge ids start at 1, so a chain ends when the id 0 is reached.
void addEdge(int fr, int to){
    cnt += 1;
    const int newest = cnt;
    pre[newest] = tail[fr];
    edges[newest] = to;
    tail[fr] = newest;
}
// Hand node o's pending add tag (o covers [l, r]) down to its children,
// covering [l, mid] and [mid + 1, r], and reset the tag on o.
void pushdown(int o, int l, int r){
    const long long pending = lazy[o];
    const int mid = (l + r) >> 1;
    const int leftLen = mid - l + 1;
    const int rightLen = r - mid;
    sum[2 * o] += pending * leftLen;
    sum[2 * o + 1] += pending * rightLen;
    lazy[2 * o] += pending;
    lazy[2 * o + 1] += pending;
    lazy[o] = 0;
}
// Add the global delta x to every position of the global range [ql, qr]
// that lies inside node o, which covers [l, r].
void update(int o, int l, int r){
    if (l >= ql && r <= qr){
        // Fully covered: apply here and defer the children through lazy[].
        // The product is formed in 64 bits to avoid int overflow.
        lazy[o] += x;
        sum[o] += static_cast<long long>(x) * (r - l + 1);
        return;
    }
    pushdown(o, l, r);
    const int mid = (l + r) >> 1;
    const int lc = o << 1;
    const int rc = lc | 1;
    if (ql <= mid) update(lc, l, mid);
    if (qr > mid) update(rc, mid + 1, r);
    sum[o] = sum[lc] + sum[rc];
}
// Sum over the global range [ql, qr], restricted to node o covering [l, r].
ll query(int o, int l, int r){
    const bool covered = ql <= l && r <= qr;
    if (covered) return sum[o];
    pushdown(o, l, r);
    const int mid = (l + r) >> 1;
    const ll fromLeft = (ql <= mid) ? query(o << 1, l, mid) : 0;
    const ll fromRight = (qr > mid) ? query(o << 1 | 1, mid + 1, r) : 0;
    return fromLeft + fromRight;
}
// Depth-first traversal from u (at level `layer`). It fills depth[], each
// child's parent fa[child][0], the DFS position map f[], and the interval
// [first[u], last[u]] of DFS positions covered by u's subtree.
void dfs(int u, int layer){
    depth[u] = layer;
    first[u] = ++step;
    f[first[u]] = u;
    const int parent = fa[u][0];
    for (int e = tail[u]; e; e = pre[e]){
        const int next = edges[e];
        if (next != parent){
            fa[next][0] = u;
            dfs(next, layer + 1);
        }
    }
    last[u] = step;
}
void makelca(){
rep(j, 1, 23){
rep(i, 1, n+1){
fa[i][j] = fa[fa[i][j-1]][j-1];
}
}
}
// Lowest common ancestor of x and y, via depth[] and the lifting table fa[][].
int lca(int x, int y){
    if (depth[x] < depth[y]) std::swap(x, y);
    // Equalize depths by lifting the deeper node x.
    for (int k = 22; k >= 0; k--){
        const int up = fa[x][k];
        if (depth[up] >= depth[y]) x = up;
    }
    if (x == y) return x;
    // Climb in lockstep while the ancestors still differ.
    for (int k = 22; k >= 0; k--){
        if (fa[x][k] == fa[y][k]) continue;
        x = fa[x][k];
        y = fa[y][k];
    }
    return fa[x][0];
}
// Return the ancestor of `son` at depth depth[father] + 1. When `father` is a
// proper ancestor of `son`, that is father's child on the path to son. If son
// is not deeper than father, son itself is returned.
int lst(int son, int father){
    for (int k = 22; k >= 0; k--){
        const int candidate = fa[son][k];
        if (depth[candidate] <= depth[father]) continue;
        son = candidate;
    }
    return son;
}
int main(){
scanf("%d%d", &n, &q);
rep(i, 1, n+1){
scanf("%d", &a[i]);
}
rep(i, 0, n-1){
scanf("%d%d", &fr, &to);
addEdge(fr, to);
addEdge(to, fr);
}
dfs(1, 1);
rep(i, 1, n+1){
ql = first[i]; qr = first[i]; x = a[i];
update(1, 1, n);
}
dprintf("\n\n\n\n\n");
makelca();
rep(i, 0, q){
dprintf("\n\n\n\n\n");
scanf("%d", &op);
if (op == 1){
scanf("%d", &root);
}
else if (op == 2){
scanf("%d%d%d", &u, &v, &x);
luv = lca(u, v); luvr = lca(luv, root);
dprintf("luv = %d luvr = %d\n", luv, luvr);
if (luv == luvr){
ql = 1; qr = n;
dprintf("update subtree of 1 x = %d\n", x);
update(1, 1, n);
lur = lca(u, root); lvr = lca(v, root);
if (lur != root && lvr != root){
dprintf("lur = %d lvr = %d\n", lur, lvr);
undo = lur == luvr? lvr : lur;
undo = lst(root, undo);
ql = first[undo]; qr = last[undo]; x = -x;
dprintf("update subtree of %d x = %d\n", undo, x);
update(1, 1, n);
}
}
else{
dprintf("update subtree of %d x = %d\n", luv, x);
ql = first[luv]; qr = last[luv];
update(1, 1, n);
}
}
else{
scanf("%d", &v);
lrv = lca(root, v);
if (v == root){
ql = 1; qr = n;
printf("%lld\n", query(1, 1, n));
}
else if (lrv != v){
ql = first[v]; qr = last[v];
printf("%lld\n", query(1, 1, n));
}
else {
ql = 1; qr = n;
pls = query(1, 1, n);
undo = lst(root, lrv);
ql = first[undo]; qr = last[undo];
mi = query(1, 1, n);
printf("%lld\n", pls - mi);
}
}
}
}