虚树的概念
虚树指的是在原树中抽出包含某特定几个节点且保留原树结构的最小的树。
虚树解决的问题
用来解决多组询问的树形
D
P
DP
DP 问题,每次询问会给出一个点集作为关键点,假设点集大小为
k
k
k,则
∑
k
\sum k
∑k 和
n
n
n 同阶。而树形
d
p
dp
dp 的转移往往可以只用关键点来解决,也就是很多非关键点的作用不是必要的。假设给出大小为
k
k
k 的点集,我们就利用这些关键点建出一棵大小为
O
(
k
)
O(k)
O(k) 的树,这棵树的结构和原树一致,然后在这颗树上进行
d
p
dp
dp。
建树的复杂度是
O
(
k
l
o
g
k
)
O(klogk)
O(klogk) 的,
d
p
dp
dp 的复杂度可以是
O
(
k
)
O(k)
O(k),也可以是
O
(
k
l
o
g
k
)
O(klogk)
O(klogk)。
举个例子
原树
关键点为 {2,4} 的虚树
关键点为 {4,5} 的虚树
关键点为 {4,6} 的虚树
关键点为 {4,6,7} 的虚树
虚树的构建
我们不妨把
1
1
1 作为虚树的根节点。
由于虚树的结构要和原树一致,我们基于原树的
d
f
s
dfs
dfs 序来完成虚树的构建。
算法流程为:
- 将关键点按 d f s dfs dfs 序排序
- 先将
1
1
1 入栈,从
d
f
s
dfs
dfs 序小到大添加节点,我们这时候的栈维护的其实是树上的一条链,然后我们分两种情况讨论:
①:栈顶节点和当前节点的 l c a lca lca 为栈顶节点,那么直接将当前节点入栈即可
②:栈顶节点和当前节点的 l c a lca lca 不是栈顶节点,如图:
这时候,栈维护的链是:
而我们需要把链变成这样:
因此我们把蓝色节点弹栈,并向虚树中的父亲连边:
弹完栈后,如果栈顶元素不是 l c a lca lca,应该把 l c a lca lca 入栈。
建树的代码
vector<int> G[N];
void build_virtual_tree(int k){
sta[top = 1] = 1; G[1].clear();//先把 1 入栈,并清空 1 的连边
for(int i = 1; i <= k; i++){
if(h[i] != 1){//如果是 1 就没必要重复进栈了
int Lca = lca(sta[top], h[i]);//获得栈顶元素和当前元素的 lca
if(Lca != sta[top]){//如果 lca 不是栈顶元素,即应该换一条链,应该不断弹栈
while(dfn[sta[top - 1]] > dfn[Lca]) G[sta[top - 1]].push_back(sta[top]), top--;//不断连边
if(dfn[Lca] > dfn[sta[top - 1]]) G[Lca].clear(), G[Lca].push_back(sta[top]), sta[top] = Lca;//lca的dfs序大于次大元素,说明lca从未入栈,则清空lca的连边,将lca和栈顶连边并将lca入栈
else G[Lca].push_back(sta[top]), top--;//如果 lca 入过栈了,直接将 lca 和栈顶连边
}
G[h[i]].clear(), sta[++top] = h[i];//将当前元素入栈
}
}
while(top > 1) G[sta[top - 1]].push_back(sta[top]), top--;//将最后一条链连边
}
[SDOI2011] 消耗战
给出 n n n 个点的一棵带有边权的树,以及 q q q 个询问.每次询问给出 k k k 个点,询问这使得这 k k k 个点与 1 1 1 点不连通所需切断的边的边权和最小是多少。
其中, n ≤ 2.5 × 1 0 5 , q ≤ 5 × 1 0 5 , ∑ k ≤ 5 × 1 0 5 , w i ≤ 1 0 5 n\le2.5\times 10^5,q\le 5\times 10^5,\sum k\le 5\times 10^5,w_i\le 10^5 n≤2.5×105,q≤5×105,∑k≤5×105,wi≤105。
如果不建虚树,每次考虑
d
p
dp
dp,设
d
p
u
dp_u
dpu 表示
u
u
u 的子树中的点全部和
1
1
1 断掉联系的答案。
令
v
a
l
i
val_i
vali 为
i
i
i 到
1
1
1 路径上边权的最小值。那么转移的时候分两种情况:
-
u
u
u 是关键点:
d p u = v a l u dp_u=val_u dpu=valu -
u
u
u 不是关键点:
d p u = min { v a l u , ∑ d p v } dp_u=\min\{val_u,\sum dp_v\} dpu=min{valu,∑dpv}
建虚树后按照同样方式转移即可。总复杂度为
O
(
(
∑
k
)
l
o
g
(
∑
k
)
)
O((\sum k)log(\sum k))
O((∑k)log(∑k))
代码如下。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define pii pair<int, int>
const int N = 3e5 + 5;
vector<pii> E[N];
int st[N][20], dfn[N], dep[N], h[N], is[N], sta[N], top, dft;
LL dp[N], val[N];
void dfs(int u, int pre){
dfn[u] = ++dft;
for(auto& e: E[u]){
int v, w;
tie(v, w) = e;
if(v == pre) continue;
dep[v] = dep[u] + 1;
st[v][0] = u;
val[v] = min(val[u], (LL)w);
dfs(v, u);
}
}
void build_st(int n){
for(int i = 1; i <= 19; i++){
for(int j = 1; j <= n; j++) st[j][i] = st[st[j][i - 1]][i - 1];
}
}
int lca(int u, int v){
if(dep[u] < dep[v]) swap(u, v);
int d = dep[u] - dep[v];
for(int i = 19; i >= 0; i--) if(d >> i & 1) u = st[u][i];
if(u == v) return u;
for(int i = 19; i >= 0; i--) if(st[u][i] != st[v][i]) u = st[u][i], v = st[v][i];
return st[u][0];
}
vector<int> G[N];
void build_virtual_tree(int k){
sta[top = 1] = 1; G[1].clear();
for(int i = 1; i <= k; i++){
if(h[i] != 1){
int Lca = lca(sta[top], h[i]);
if(Lca != sta[top]){
while(dfn[sta[top - 1]] > dfn[Lca]) G[sta[top - 1]].push_back(sta[top]), top--;
if(dfn[Lca] > dfn[sta[top - 1]]) G[Lca].clear(), G[Lca].push_back(sta[top]), sta[top] = Lca;
else G[Lca].push_back(sta[top]), top--;
}
G[h[i]].clear(), sta[++top] = h[i];
}
}
while(top > 1) G[sta[top - 1]].push_back(sta[top]), top--;
}
void dfs1(int u, int pre){
LL sum = 0;
dp[u] = 0;//注意每次的 dp 数组要初始化,由于只会用到虚树上的点,对这些点 dfs 的时候初始化即可
for(int& v: G[u]){
if(v == pre) continue;
dfs1(v, u);
sum += dp[v];
}
if(is[u]) dp[u] = val[u];
else dp[u] = min(sum, (LL)val[u]);
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int n;
cin >> n;
for(int i = 1, u, v, w; i < n; i++){
cin >> u >> v >> w;
E[u].emplace_back(v, w);
E[v].emplace_back(u, w);
}
val[1] = 1e18;
dfs(1, 0);
build_st(n);
int m;
cin >> m;
while(m--){
int k;
cin >> k;
for(int i = 1; i <= k; i++) cin >> h[i], is[h[i]] = 1;
sort(h + 1, h + k + 1, [&](int x, int y){
return dfn[x] < dfn[y];
}); //按照 dfs 序排序
build_virtual_tree(k);//建虚树
dfs1(1, 0);//在虚树上 dp
cout << dp[1] << '\n';
for(int i = 1; i <= k; i++) is[h[i]] = 0;//把关键点清空
}
return 0;
}
[CF613D] Kingdom and its Cities
给出一棵 n n n 个点的树和 m m m 次询问,每次询问给出 k k k 个点,问最少删掉树上多少个点,使得这 k k k 个点互不连通,无解输出 − 1 -1 −1。
n , m , ∑ k ≤ 1 0 5 n,m,\sum k\le 10^5 n,m,∑k≤105。
如果存在两个关键点相邻,则肯定无解,否则有解。
然后套路建出虚树,设
d
p
u
dp_u
dpu 表示以
u
u
u 为根的子树中关键点两两不连通的最小代价,再用一个
s
z
u
sz_u
szu 表示
u
u
u 的子树中有多少个关键点向
u
u
u 延伸。
那么,首先有:
d
p
u
=
∑
v
∈
s
o
n
(
u
)
d
p
v
dp_u=\sum\limits_{v\in son(u)} dp_v
dpu=v∈son(u)∑dpv
如果
u
u
u 是关键点,那么
u
u
u 必须和所有的
s
z
u
sz_u
szu 断开,即:
d
p
u
+
=
s
z
u
,
s
z
u
=
1
dp_u+=sz_u,sz_u=1
dpu+=szu,szu=1
如果
u
u
u 不是关键点,如果
s
z
u
>
1
sz_u>1
szu>1,那么必须在
u
u
u 处断开,即
d
p
u
+
=
1
,
s
z
u
=
0
dp_u+=1,sz_u=0
dpu+=1,szu=0
代码如下。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define pii pair<int, int>
const int N = 3e5 + 5;
vector<int> E[N];
int st[N][20], dfn[N], dep[N], h[N], is[N], sta[N], top, dft, dp[N], sz[N], flag;
void dfs(int u, int pre){
dfn[u] = ++dft;
for(int& v: E[u]){
if(v == pre) continue;
dep[v] = dep[u] + 1;
st[v][0] = u;
dfs(v, u);
}
}
void build_st(int n){
for(int i = 1; i <= 19; i++){
for(int j = 1; j <= n; j++) st[j][i] = st[st[j][i - 1]][i - 1];
}
}
int lca(int u, int v){
if(dep[u] < dep[v]) swap(u, v);
int d = dep[u] - dep[v];
for(int i = 19; i >= 0; i--) if(d >> i & 1) u = st[u][i];
if(u == v) return u;
for(int i = 19; i >= 0; i--) if(st[u][i] != st[v][i]) u = st[u][i], v = st[v][i];
return st[u][0];
}
vector<int> G[N];
void build_virtual_tree(int k){
sta[top = 1] = 1; G[1].clear();//先把 1 入栈,并清空 1 的连边
for(int i = 1; i <= k; i++){
if(h[i] != 1){//如果是 1 就没必要重复进栈了
int Lca = lca(sta[top], h[i]);//获得栈顶元素和当前元素的 lca
if(Lca != sta[top]){//如果 lca 不是栈顶元素,即应该换一条链,应该不断弹栈
while(dfn[sta[top - 1]] > dfn[Lca]) G[sta[top - 1]].push_back(sta[top]), top--;//不断连边
if(dfn[Lca] > dfn[sta[top - 1]]) G[Lca].clear(), G[Lca].push_back(sta[top]), sta[top] = Lca;//lca的dfs序大于次大元素,说明lca从未入栈,则清空lca的连边,将lca和栈顶连边并将lca入栈
else G[Lca].push_back(sta[top]), top--;//如果 lca 入过栈了,直接将 lca 和栈顶连边
}
G[h[i]].clear(), sta[++top] = h[i];//将当前元素入栈
}
}
while(top > 1) G[sta[top - 1]].push_back(sta[top]), top--;//将最后一条链连边
}
void dfs1(int u, int pre){
dp[u] = sz[u] = 0;
for(int& v: G[u]){
if(v == pre) continue;
dfs1(v, u);
if(is[u] && is[v] && u == st[v][0]) flag = 1;
sz[u] += sz[v];
dp[u] += dp[v];
}
if(is[u]) dp[u] += sz[u], sz[u] = 1;
else{
if(sz[u] > 1) dp[u]++, sz[u] = 0;
}
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int n;
cin >> n;
for(int i = 1, u, v; i < n; i++){
cin >> u >> v;
E[u].push_back(v);
E[v].push_back(u);
}
dfs(1, 0);
build_st(n);
int m;
cin >> m;
while(m--){
int k;
cin >> k;
for(int i = 1; i <= k; i++) cin >> h[i], is[h[i]] = 1;
sort(h + 1, h + k + 1, [&](int x, int y){
return dfn[x] < dfn[y];
});
build_virtual_tree(k);
flag = 0;
dfs1(1, 0);
for(int i = 1; i <= k; i++) is[h[i]] = 0;
if(flag) cout << -1 << '\n';
else cout << dp[1] << '\n';
}
return 0;
}
[HNOI2014] 世界树
给出一棵 n n n 个点的树和 m m m 次询问,每次询问给出 k k k 个点作为关键点,树上的每个点由最近的关键点控制,问每个关键点会控制多少个点。
n , m , k ≤ 3 × 1 0 5 n,m,k\le 3\times 10^5 n,m,k≤3×105。
套路建出虚树。
先两次
d
f
s
dfs
dfs 求出虚树上每个点由哪个点控制。
然后再用一次
d
f
s
dfs
dfs 来
d
p
dp
dp,记
o
c
u
oc_u
ocu 表示控制
u
u
u 的关键点。
一开始初始化
d
p
o
c
u
=
s
z
u
dp_{oc_u}=sz_u
dpocu=szu
如果
o
c
u
=
o
c
v
oc_u=oc_v
ocu=ocv,那么
d
p
o
c
u
=
d
p
o
c
u
−
s
z
v
dp_{oc_u}=dp_{oc_u}-sz_v
dpocu=dpocu−szv,因为在子树中已经统计过了。
如果
o
c
u
≠
o
c
v
oc_u\neq oc_v
ocu=ocv,那么我们应该从
v
v
v 往上找到最后一点
p
p
p,满足
p
p
p 由
o
c
v
oc_v
ocv 控制,
f
a
p
fa_p
fap 由
o
c
u
oc_u
ocu 控制。找
p
p
p 可以通过倍增求出。然后
d
p
o
c
u
=
d
p
o
c
u
−
s
z
p
,
d
p
o
c
v
=
d
p
o
c
v
+
s
z
p
−
s
z
x
dp_{oc_u}=dp_{oc_u}-sz_p,dp_{oc_v}=dp_{oc_v}+sz_p-sz_x
dpocu=dpocu−szp,dpocv=dpocv+szp−szx。
细节颇多,建议读者自己手写一番。。。
代码如下。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define pii pair<int, int>
const int N = 3e5 + 5;
vector<int> E[N];
int st[N][20], dfn[N], dep[N], h[N], re[N], is[N], sta[N], top, dft, sz[N], oc[N];
LL dp[N];
void dfs(int u, int pre){
dfn[u] = ++dft; sz[u] = 1;
for(int& v: E[u]){
if(v == pre) continue;
dep[v] = dep[u] + 1;
st[v][0] = u;
dfs(v, u);
sz[u] += sz[v];
}
}
void build_st(int n){
for(int i = 1; i <= 19; i++){
for(int j = 1; j <= n; j++) st[j][i] = st[st[j][i - 1]][i - 1];
}
}
int lca(int u, int v){
if(dep[u] < dep[v]) swap(u, v);
int d = dep[u] - dep[v];
for(int i = 19; i >= 0; i--) if(d >> i & 1) u = st[u][i];
if(u == v) return u;
for(int i = 19; i >= 0; i--) if(st[u][i] != st[v][i]) u = st[u][i], v = st[v][i];
return st[u][0];
}
int getd(int u, int d){
for(int i = 19; i >= 0; i--) if(d >> i & 1) u = st[u][i];
return u;
}
vector<int> G[N];
void build_virtual_tree(int k){
sta[top = 1] = 1; G[1].clear();
for(int i = 1; i <= k; i++){
if(h[i] != 1){
int Lca = lca(sta[top], h[i]);
if(Lca != sta[top]){
while(dfn[sta[top - 1]] > dfn[Lca]) G[sta[top - 1]].push_back(sta[top]), top--;
if(dfn[Lca] > dfn[sta[top - 1]]) G[Lca].clear(), G[Lca].push_back(sta[top]), sta[top] = Lca;
else G[Lca].push_back(sta[top]), top--;
}
G[h[i]].clear(), sta[++top] = h[i];
}
}
while(top > 1) G[sta[top - 1]].push_back(sta[top]), top--;
}
int getdis(int u, int v){
return dep[u] + dep[v] - 2 * dep[lca(u, v)];
}
void dfs1(int u, int pre){
oc[u] = 0;
dp[u] = 0;
for(int& v: G[u]){
if(v == pre) continue;
dfs1(v, u);
int x = getdis(oc[u], u), y = getdis(oc[v], u);
if(!oc[u] || y < x || (y == x && oc[v] < oc[u])) oc[u] = oc[v];
}
if(is[u]) oc[u] = u;
}
void dfs2(int u, int pre){
for(int& v: G[u]){
if(v == pre) continue;
int x = getdis(oc[v], v), y = getdis(oc[u], v);
if(!oc[v] || y < x || (y == x && oc[u] < oc[v])) oc[v] = oc[u];
dfs2(v, u);
}
}
void dfs3(int u, int pre){
dp[oc[u]] += sz[u];
for(int& v: G[u]){
if(v == pre) continue;
dfs3(v, u);
if(oc[u] == oc[v]) dp[oc[u]] -= sz[v];
else{
int x = getdis(oc[u], u), y = getdis(oc[v], v), w = dep[v] - dep[u] - 1;
int d = w + x - y, k;
k = d / 2;
if(d > 0 && d % 2 && oc[v] < oc[u]) k++;
int p = getd(v, k);
dp[oc[v]] += sz[p] - sz[v];
dp[oc[u]] -= sz[p];
}
}
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int n;
cin >> n;
for(int i = 1, u, v; i < n; i++){
cin >> u >> v;
E[u].push_back(v);
E[v].push_back(u);
}
dfs(1, 0);
build_st(n);
int m;
cin >> m;
while(m--){
int k;
cin >> k;
for(int i = 1; i <= k; i++) cin >> h[i], re[i] = h[i], is[h[i]] = 1;
sort(h + 1, h + k + 1, [&](int x, int y){
return dfn[x] < dfn[y];
});
build_virtual_tree(k);
dfs1(1, 0);
dfs2(1, 0);
dfs3(1, 0);
for(int i = 1; i <= k; i++) is[h[i]] = 0;
for(int i = 1; i <= k; i++) cout << dp[re[i]] << ' ';
cout << '\n';
}
return 0;
}