P2495 [SDOI2011]消耗战
树形dp
状态表示:
f
u
f_u
fu表示以
u
u
u为根的子树中,
u
u
u节点与子树中的关键的“隔开”所需要的最小代价
状态转移:
考虑
u
u
u的一个儿子
v
v
v
- v v v是关键点: f u = f u + w u → v f_u=f_u+w_{u\to v} fu=fu+wu→v
- v v v不是关键的: f u = f u + min ( w u → v , f v ) f_u=f_u+\min(w_{u\to v},f_v) fu=fu+min(wu→v,fv)
于是有下面暴力代码 O ( n m + ∑ k ) O(nm+\sum k) O(nm+∑k)
#define IO ios::sync_with_stdio(false);cin.tie();cout.tie(0)
//#pragma GCC optimize(2)
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
using ll=long long;
constexpr int N=250010;
int h[N],e[2*N],ne[2*N],w[2*N],idx;
void add(int a,int b,int c){e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++;}
ll f[N];
int n,m;
bool mp[N];
void dfs(int u,int fa)
{
for(int i=h[u];i!=-1;i=ne[i])
{
int v=e[i];
if(v==fa) continue;
dfs(v,u);
if(mp[v])
f[u]+=w[i];
else
f[u]+=min((ll)w[i],f[v]);
}
}
int main()
{
IO;
cin>>n;
memset(h,-1,sizeof h);
for(int i=1;i<n;i++)
{
int a,b,c;
cin>>a>>b>>c;
add(a,b,c),add(b,a,c);
}
cin>>m;
while(m--)
{
memset(f,0ll,sizeof(ll)*(n+1));
memset(mp,0,sizeof(bool)*(n+1));
int k;
cin>>k;
while(k--)
{
int p;
cin>>p;
mp[p]=1;
}
dfs(1,0);
cout<<f[1]<<'\n';
}
return 0;
}
虚树:根据原树构建一颗虚拟的树,这棵树 只 包 含 {\color{red}只包含} 只包含关键节点以及关键节点的最近公共祖先(LCA)
构建过程:
- 将关键节点的时间戳排序
- 用一个栈维护一条虚树上的链,根节点到当前关键节点,将关键点依次push进栈中
push过程:
大佬图片非常清晰
下面用top表示栈顶元素,cur表示当前需要插入的节点
anc=lca(top,cur)
首先如果说anc=top,说明cur应该接在栈顶后面即可
否则会出现下面情况
当前栈中维护的是
绿
色
{\color{green}绿色}
绿色的那一条链,我们需要让当前栈维护从根节点到当前节点即
蓝
色
{\color{blue}蓝色}
蓝色那条链,只需要让top-1向top连边,并且不断top- -即可。
最后需要判断是否存在下面情况
如果存在,需要连一条anc到top的边,然后弹出top- -,并将anc入栈
代码如下
void insert(int u)
{
int anc=lca(u,stk[tt]);
while(tt>1&&dfn[stk[tt-1]]>=dfn[anc])
E[stk[tt-1]].push_back(stk[tt]),tt--;
if(stk[tt]!=anc) E[anc].push_back(stk[tt]),stk[tt]=anc;//最后的情况
stk[++tt]=u;
}
void build()// 构建虚树
{
sort(is+1,is+1+m,[](const int &a,const int &b){return dfn[a]<dfn[b];});//按照dfn排序
stk[tt=1]=1;//根节点
for(int i=1;i<=m;i++) insert(is[i]);
while(tt) E[stk[tt-1]].push_back(stk[tt]),tt--;// 连边
}
显然虚树两点(u,v)之间边 u → v u\to v u→v的大小应为原树中路径 min ( u ⇝ v ) \min(u \leadsto v) min(u⇝v),倍增求lca过程即可求出边权。
而下面代码采取另一种做法:
首先预处理原树中根节点到当前节点路径的最小值
d
p
u
dp_u
dpu表示从u开始不能到达其子树中的关键点所需切断的最小边权和。
切断儿子
v
v
v要么用
m
n
[
v
]
mn[v]
mn[v]要么切断子树,如果当前节点是关键节点,必须切断当前节点即花费代价为
m
n
[
u
]
mn[u]
mn[u]
注意:虚树清空节点需要在dfs过程中清空,不能使用memset
时间复杂度 ∑ k log n \sum k\log n ∑klogn
#define IO ios::sync_with_stdio(false);cin.tie();cout.tie(0)
//#pragma GCC optimize(2)
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
using ll=long long;
constexpr int N=250010;
int h[N],e[2*N],ne[2*N],w[2*N],idx;
void add(int a,int b,int c){e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++;}
int dep[N],sz[N],son[N],fa[N];
int mn[N];
void dfs1(int u)
{
sz[u]=1;
for(int i=h[u];i!=-1;i=ne[i])
{
int v=e[i];
if(v==fa[u]) continue;
fa[v]=u;
dep[v]=dep[u]+1;
mn[v]=min(mn[u],w[i]);
dfs1(v);
sz[u]+=sz[v];
if(sz[son[u]]<sz[v]) son[u]=v;
}
}
int top[N],dfn[N],timestamp;
void dfs2(int u,int t)
{
dfn[u]=++timestamp;
top[u]=t;
if(!son[u]) return;
dfs2(son[u],t);
for(int i=h[u];i!=-1;i=ne[i])
{
int v=e[i];
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
int lca(int u,int v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);
u=fa[top[u]];
}
return dep[u]<dep[v]?u:v;
}
//====================================================树剖求lca
int n,m;
int stk[N],tt;
int is[N];
bool mp[N];
vector<int> E[N];
void insert(int u)
{
int anc=lca(u,stk[tt]);
while(tt>1&&dfn[stk[tt-1]]>=dfn[anc])
E[stk[tt-1]].push_back(stk[tt]),tt--;
if(stk[tt]!=anc) E[anc].push_back(stk[tt]),stk[tt]=anc;
stk[++tt]=u;
}
void build()// 构建虚树
{
sort(is+1,is+1+m,[](const int &a,const int &b){return dfn[a]<dfn[b];});//按照dfn排序
stk[tt=1]=1;//根节点
for(int i=1;i<=m;i++) insert(is[i]);
while(tt) E[stk[tt-1]].push_back(stk[tt]),tt--;
}
ll dfs3(int u)
{
ll cost=0;
for(int v:E[u]) cost+=min((ll)mn[v],dfs3(v));
E[u].clear();
if(mp[u]) return mn[u];
else return cost;
}
int main()
{
IO;
cin>>n;
memset(h,-1,sizeof h);
memset(mn,0x3f,sizeof mn);
for(int i=1;i<n;i++)
{
int a,b,c;
cin>>a>>b>>c;
add(a,b,c),add(b,a,c);
}
dfs1(1);
dfs2(1,1);
int q;
cin>>q;
while(q--)
{
cin>>m;
for(int i=1;i<=m;i++)
{
cin>>is[i];
mp[is[i]]=1;
}
build();
cout<<dfs3(1)<<'\n';
for(int i=1;i<=m;i++) mp[is[i]]=0;
}
return 0;
}
要加油哦~