虚树入门题。
所谓虚树就是只保留需要的关键节点及互相的lca进行重建树,在虚树上跑DP之类的,使得复杂度与关键点个数相关,而且可能能简化问题。
虚树的建法是按dfs序的顺序加入,维护一个当前链的栈,类似dfs一样搞搞就好了。网上有一些资料…
建虚树代码:
void buildVT(){
inT.clear(); sort(S.begin(),S.end(),_cmp);
stk[top=1]=1; inT.push_back(1);
for(int i=0;i<S.size();i++){
int x=S[i], _lca=LCA(x,stk[top]); inT.push_back(x); inT.push_back(_lca);
while(top>1&&dfn[stk[top-1]]>dfn[_lca]) Vadd(stk[top-1],stk[top]), top--;
if(top==1||stk[top]==_lca){ stk[++top]=x; continue; }
if(stk[top-1]==_lca) Vadd(stk[top-1],stk[top]), top--; else Vadd(_lca,stk[top]), stk[top]=_lca;
stk[++top]=x;
}
while(top>1) Vadd(stk[top-1],stk[top]), top--;
}
这题树形DP很显然,建好虚树的边权是对应原树中链上最小边…
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;
const int maxn=300005,maxe=600005,maxq=300005;
int n,Q,Tim,dfn[maxn],fir[maxn],nxt[maxe],son[maxe],w[maxe],tot;
void add(int x,int y,int z){
son[++tot]=y; w[tot]=z; nxt[tot]=fir[x]; fir[x]=tot;
}
int dep[maxn],anc[maxn][20],_min[maxn][20];
void dfs_info(int x,int pre,int w_pre){
dfn[x]=++Tim; anc[x][0]=pre; _min[x][0]=w_pre;
for(int i=1;i<=18;i++) anc[x][i]=anc[anc[x][i-1]][i-1], _min[x][i]=min(_min[x][i-1],_min[anc[x][i-1]][i-1]);
for(int j=fir[x];j;j=nxt[j]) if(son[j]!=pre){
dep[son[j]]=dep[x]+1; dfs_info(son[j],x,w[j]);
}
}
int LCA(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=18;i>=0;i--) if(dep[anc[x][i]]>=dep[y]) x=anc[x][i];
if(x==y) return x;
for(int i=18;i>=0;i--) if(anc[x][i]!=anc[y][i]) x=anc[x][i], y=anc[y][i];
return anc[x][0];
}
vector<int> S,inT;
int _cmp(int A,int B){ return dfn[A]<dfn[B]; }
int stk[maxn],top,vfir[maxn],vnxt[maxe],vw[maxn],vson[maxe],vtot;
int Query_min(int x,int y){
int res=1e9;
for(int i=18;i>=0;i--) if(dep[anc[x][i]]>=dep[y]) res=min(res,_min[x][i]), x=anc[x][i];
return res;
}
void Vadd(int x,int y){
vson[++vtot]=y; vw[vtot]=Query_min(y,x); vnxt[vtot]=vfir[x]; vfir[x]=vtot;
//printf("%d --- %d %d\n",x,y,vw[vtot]);
}
void buildVT(){
inT.clear(); sort(S.begin(),S.end(),_cmp);
stk[top=1]=1; inT.push_back(1);
for(int i=0;i<S.size();i++){
int x=S[i], _lca=LCA(x,stk[top]); inT.push_back(x); inT.push_back(_lca);
while(top>1&&dfn[stk[top-1]]>dfn[_lca]) Vadd(stk[top-1],stk[top]), top--;
if(top==1||stk[top]==_lca){ stk[++top]=x; continue; }
if(stk[top-1]==_lca) Vadd(stk[top-1],stk[top]), top--; else Vadd(_lca,stk[top]), stk[top]=_lca;
stk[++top]=x;
}
while(top>1) Vadd(stk[top-1],stk[top]), top--;
}
int b[maxn];
LL f[maxn];
bool vis[maxn];
void dfs_dp(int x,int pre){
if(vis[x]) f[x]=1e18; else{
f[x]=0;
for(int j=vfir[x];j;j=vnxt[j]) if(vson[j]!=pre)
dfs_dp(vson[j],x), f[x]+=min(f[vson[j]],(LL)vw[j]);
}
}
void Solve(){
S.clear();
scanf("%d",&b[0]);
for(int i=1;i<=b[0];i++) scanf("%d",&b[i]), S.push_back(b[i]), vis[b[i]]=true;
buildVT();
dfs_dp(1,1); printf("%lld\n",f[1]);
for(int i=1;i<=b[0];i++) vis[b[i]]=false;
vtot=0; for(int i=0;i<inT.size();i++) vfir[inT[i]]=0;
}
int main(){
freopen("bzoj2286.in","r",stdin);
freopen("bzoj2286.out","w",stdout);
scanf("%d",&n);
for(int i=1;i<=n-1;i++){
int x,y,z; scanf("%d%d%d",&x,&y,&z);
add(x,y,z); add(y,x,z);
}
dfs_info(1,1,1e9);
scanf("%d",&Q);
while(Q--) Solve();
return 0;
}