虚树+LCA
新学虚树,在此记录一下。
首先得搞明白什么是虚树
虚树经常被使用在树形dp中,往往会给你一颗相对复杂的树。然后给你一系列的查询条件,即包含一系列的查询关键点,然后让你在这些少量的关键点中进行查询遍历得出你想要的结果。而此时你对整棵树dp在时间上是不允许的。而这时你就要通过其他途径仅仅对要查询的关键点进行遍历,而减去其他的无用的树节点。这时,虚树这个概念就可以引入了。我们建立一颗仅仅包含部分关键结点的虚树,将非关键点构成的链简化成边或是剪去,在虚树上进行dp。
建立虚树之前的准备工作
- 首先你得对每个节点进行一次遍历,然后获取它们的时间序,同时也可以获取每个节点的深度,这对 后面LCA求最近公共祖先是必要的。
- 其次LCA算法,单次询问O(logn)的倍增
- 最后,对每个查询点进行时间序的排序。
如何构建虚树
我们使用栈stack来维护所谓的最右链,top为栈顶位置。在一开始,最右链上的边并没有被加入虚树,这是因为在接下来的过程中随时会有某个lca插到最右链中。初始无条件将第一个询问点加入栈stack中。然后接下来依次将排序后的节点加入,加入该询问点为now,此时它们的最近公共祖先lc = lca(now,stack[top]);要考虑stack[top]和stack[top - 1]以及lc之间的关系。
这里可以分情况讨论
1 lc = stack[top],也就是说now在stack[top]的子树中。此时只需将now入栈即可。
2 lc = stack[top - 1], 直接把lc -> stack[top]这条边加入虚树,stack[top]出栈,top-- ,再将now入栈。
3 lc 在stack[top]和stack[top-1]之间,也是先把lc -> stack[top]这台边加入虚树,stack[top]出栈,top–,
再将lc和now都入栈。与情况二有点类似。
4 dep[lc] < dep[stack[top - 1]],我们需要循环依次将stack[top - 1] -> stack[top]这条边依次加入虚树,stack[top]出栈,top–。直到出现情况二为止。
以上四种情况自己手动画下,如若实在不懂,可以去其他大佬的博客进行食用,他们的更加通俗易懂。
下面引入一个例子。
P2495 [SDOI2011]消耗战
一些问题
对于每次询问都要进行清除虚树构建,因为每一次询问都要重新构建一颗虚树。
这里涉及的最短路径和,我们预处理出minv[u]代表从1到u节点路径上最小的边权。如果u是询问点,那么切断u及其子树上询问点的最小代价dp[u] = minv[u],否则,最小代价dp(u)=min(minv[u],sum)其中sum是u节点的子节点的边权之和。这里得注意的是,本来你找到了最近的查询点后,本来不用管子节点此时的边权是最小的,但是,你还是得查询这个节点的儿子进行遍历清除整棵虚树。
然后直接上代码。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn = 500050;
int head[maxn],n,M,tol,dep[maxn],m[maxn],tol1;
int list[maxn],flag[maxn],stack[maxn],head1[maxn];
int fa[25][maxn],dfn[maxn],dfncnt = 1;
long long minv[maxn];
struct edge{
int to,next;
long long val;
}e[maxn << 1],e1[maxn << 1];
bool cmp(int u, int v)
{
return dfn[u] < dfn[v];
}
void add(int u, int v, long long w)
{
++tol;
e[tol].to = v;
e[tol].next = head[u];
e[tol].val = w;
head[u] = tol;
}
void dfs(int u, int f, int d)
{
int k;
dep[u] = d;
//后面要根据遍历的时间序来依次插入要询问的节点
dfn[u] = dfncnt ++;
//fa[k][u]表示对于u节点向上走2^k步所到达的位置
for( k = 0; fa[k][u]; k ++)
{
fa[k + 1][u] = fa[k][fa[k][u]];
}
m[u] = k; //所能走到的最大k的位置
for(int i = head[u]; i; i = e[i].next)
{
if(!dep[e[i].to])
{
fa[0][e[i].to] = u;
//到达每个节点所经过路径的最小的那条
minv[e[i].to] = min(minv[u],e[i].val);
dfs(e[i].to,u,d + 1);
}
}
}
int lca(int u,int v)
{
if(dep[u] < dep[v])
swap(u,v);
//枚举所能向上走的步数
for(int i = m[u]; i >= 0; i --)
{
if(dep[fa[i][u]] >= dep[v])
{
u = fa[i][u];
}
}
if(u == v)
return u;
for(int i = m[u]; i >= 0; i --)
{
if(fa[i][u] != fa[i][v])
{
u = fa[i][u];
v = fa[i][v];
}
}
return fa[0][u];
}
long long dfs1(int u)
{
long long sum = 0,Min;
for(int i = head1[u]; i ; i = e1[i].next)
{
sum += dfs1(e1[i].to);
}
//如果这个节点是标记点 即为查询目标点
if(flag[u] == 1)
Min = minv[u];
else
Min = min(minv[u],sum);
//清除虚树
flag[u] = 0;
head1[u] = 0;
return Min;
}
void add1(int u,int v)
{
++ tol1;
e1[tol1].to = v;
e1[tol1].next = head1[u];
head1[u] = tol1;
}
int main()
{
cin >> n;
minv[1] = 0x3f3f3f3f3f3f3f3f;
for(int i = 1;i < n; i ++)
{
int u,v;
long long w;
scanf("%d%d%lld",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
dfs(1,0,1);
cin >> M;
for(int i = 1; i <= M; i ++)
{
int k;
cin >> k;
for(int j = 1; j <= k; j ++)
{
scanf("%d",&list[j]);
//标记要询问的目标点
flag[list[j]] = 1;
}
//对list中的每个节点根据时间序排序
sort(list + 1, list + k + 1,cmp);
int top = 1;
stack[top] = list[1];
for(int j = 2; j <= k; j ++)
{
int now = list[j];
//找到它们的最近公共祖先
int lc = lca(now,stack[top]);
//构建虚树
while(1)
{
if(dep[lc] >= dep[stack[top - 1]])
{
if(lc != stack[top])
{
add1(lc,stack[top]);
if(lc != stack[top - 1])
stack[top] = lc;
else
top --;
}
break;
}else
{
add1(stack[top - 1],stack[top]);
top --;
}
}
stack[++top] = now;
}
//将最右链放入虚树
while(--top)
{
add1(stack[top],stack[top + 1]);
}
//计算结果
cout << dfs1(stack[1]) << endl;
//重新构建虚树初始化
tol1 = 0;
}
return 0;
}