消耗战-虚树+树形dp
题目描述
题解
首先考虑 m = 1 m=1 m=1,也就是只有一次询问的情况。我们考虑暴力 d p dp dp
设 f [ x ] f[x] f[x]为处理完以 x x x为根的树的最小代价
转移分为两种情况
1 1 1.断开自己与父亲的联系,代价为从根到该节点的最小值(记为 m i n x [ y ] , y minx[y],y minx[y],y为当前节点)
2 2 2.不考虑该节点(前提是该节点不是询问点),把子树内的所有询问点都断开的代价
即 f [ x ] = ∑ ( m i n ( m i n x [ y ] , f [ y ] ) f[x]=∑(min(minx[y],f[y]) f[x]=∑(min(minx[y],f[y]) ( ( ( y y y是 x x x的子节点 ) ) )
但是这样的复杂度是 O ( n m ) O(nm) O(nm)的,显然无法 A C AC AC
然而我们发现 ∑ k ∑k ∑k是比较小的,我们可不可以对 k k k下手呢?
于是,虚树诞生了
虚树学习博客:https://www.cnblogs.com/zwfymqz/p/9175152.html
代码实现
#include<bits/stdc++.h>//虚树
#define M 250009
using namespace std;
int nxt[M*2],to[M*2],first[M],tot,cnt;
int n,m,k,dep[M],f[M][23],dfn[M],q[M],top,a[M];
long long minx[M],w[M*2];
vector<int>v[M];
int read(){
int f=1,re=0;
char ch;
for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
if(ch=='-'){f=-1,ch=getchar();}
for(;isdigit(ch);ch=getchar()) re=(re<<3)+(re<<1)+ch-'0';
return re*f;
}
bool cmp(const int &a,const int &b){
return dfn[a]<dfn[b];
}
void add(int x,int y,int z){
nxt[++tot]=first[x],first[x]=tot,to[tot]=y,w[tot]=z;
nxt[++tot]=first[y],first[y]=tot,to[tot]=x,w[tot]=z;
}
void init(int u,int fa){
for(int i=1;i<=20;i++)
f[u][i]=f[f[u][i-1]][i-1];
dep[u]=dep[fa]+1,dfn[u]=++cnt;
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(v==fa) continue;
minx[v]=min(minx[u],w[i]);
f[v][0]=u,init(v,u);
}
}
int LCA(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=20;i>=0;i--){
if(dep[f[x][i]]>=dep[y]) x=f[x][i];
if(x==y) return x;
}
for(int i=20;i>=0;i--)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
void insert(int x){
if(top==1){q[++top]=x;return;}
int lca=LCA(x,q[top]);
if(lca==q[top]) return;//此处x不用入栈
//因为这样insert的话任意时刻如果lca=st[tp]的话那么lca一定是标记点(lca=1除外)
//既然lca和x都是标记点,那么x就不用管了呗,相当于把x的决策合并到lca上去了。
while(top>1&&dfn[q[top-1]]>=dfn[lca]) v[q[top-1]].push_back(q[top]),top--;
if(q[top]!=lca) v[lca].push_back(q[top]),q[top]=lca;
q[++top]=x;
}
long long getans(int x){
if(v[x].size()==0) return minx[x];
long long sum=0;
for(int i=0;i<v[x].size();i++){
int u=v[x][i];
sum+=getans(u);
}v[x].clear();
return min(sum,minx[x]);
}
signed main(){
n=read(),minx[1]=1ll<<60;
int x,y,z;
for(int i=1;i<n;i++){
x=read(),y=read(),z=read();
add(x,y,z);
}init(1,0),m=read();
for(int cas=1;cas<=m;cas++){
k=read();
for(int i=1;i<=k;i++) a[i]=read();
sort(a+1,a+k+1,cmp);
q[top=1]=1;
for(int i=1;i<=k;i++) insert(a[i]);
while(top>1) v[q[top-1]].push_back(q[top]),top--;
printf("%lld\n",getans(1));
}
}