所谓的虚树,就是只保留有用节点,把其他点都缩掉。
比如此题,对于每一个询问我们都去树形dp一遍的话就是
O(nm)
,T飞。
但是因为总的K<=500000,我们考虑可不可以每次询问时把树缩一缩,把树的大小控制在
O(k)
,这样总的复杂度就是
O(k)
得了。
可以的!我们考虑只保留询问的K个关键点,以及这些关键点两两之间的lca,可以证明,lca最多只有K-1个。这样这棵树就是 O(K) 得了!
我们考虑怎么把这棵虚树建出来,把所有关键点按dfs序排序,然后维护一个栈,栈中节点始终是一条链上的。考虑新加入的节点x和栈顶节点y,以及他们的lca t。
如果y==t,说明x是y的后代,直接加进栈中即可。
否则一直弹栈直到栈顶元素深度小于等于lca的深度,这时看栈顶元素是不是lca,如果不是的话就把lca入栈,再把x入栈,这样栈中节点还是一条链上的。
把y弹出栈时可以建边(qq[top],y),因为他们之间不可能有新的lca了。
这样我们就得到了一颗 O(k) 的虚树,去树形dp即可。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 250010
#define ll long long
#define inf 1LL<<60
inline char gc(){
static char buf[1<<16],*S,*T;
if(S==T){T=(S=buf)+fread(buf,1,1<<16,stdin);if(T==S) return EOF;}
return *S++;
}
inline int read(){
int x=0,f=1;char ch=gc();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=gc();
return x*f;
}
int n,h[N],num=0,dfn[N],dfnum=0,fa[N][20],dep[N],Log[N],m,a[N],qq[N];
ll f[N],mn[N];bool mark[N];
struct edge{
int to,next,val;
}data[N<<1];
inline void add(int x,int y){
data[++num].to=y;data[num].next=h[x];h[x]=num;
}
inline void dfs(int x){
dfn[x]=++dfnum;
for(int i=1;i<=Log[n];++i){
if(!fa[x][i-1]) break;
fa[x][i]=fa[fa[x][i-1]][i-1];
}for(int i=h[x];i;i=data[i].next){
int y=data[i].to;if(y==fa[x][0]) continue;
fa[y][0]=x;dep[y]=dep[x]+1;mn[y]=min(mn[x],(ll)data[i].val);dfs(y);
}
}
inline int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
int d=dep[x]-dep[y];
for(int i=0;i<=Log[d];++i)
if(d>>i&1) x=fa[x][i];
if(x==y) return x;
for(int i=Log[n];i>=0;--i)
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
inline void dp(int x){
ll tmp=0;
for(int i=h[x];i;i=data[i].next){
int y=data[i].to;dp(y);tmp+=f[y];
}if(mark[x]) f[x]=mn[x],mark[x]=0;//一定要砍
else f[x]=min(tmp,mn[x]);h[x]=0;
}
inline void solve(){
m=read();int top=0;num=0;
for(int i=1;i<=m;++i) a[i]=read(),mark[a[i]]=1;sort(a+1,a+m+1,cmp);
qq[++top]=1;
for(int i=1;i<=m;++i){
int t=lca(qq[top],a[i]);
while(dep[qq[top]]>dep[t]){
int x=qq[top--];
if(dep[qq[top]]<dep[t]) qq[++top]=t;
add(qq[top],x);
}if(a[i]!=qq[top]) qq[++top]=a[i];
}int x=qq[top--];while(top) add(qq[top],x),x=qq[top],--top;
dp(1);printf("%lld\n",f[1]);
}
int main(){
// freopen("a.in","r",stdin);
n=read();Log[0]=-1;mn[1]=inf;
for(int i=1;i<=n;++i) Log[i]=Log[i>>1]+1;
for(int i=1;i<n;++i){
int x=read(),y=read(),val=read();
data[++num].to=y;data[num].next=h[x];h[x]=num;data[num].val=val;
data[++num].to=x;data[num].next=h[y];h[y]=num;data[num].val=val;
}dfs(1);int owo=read();memset(h,0,sizeof(h));num=0;
while(owo--) solve();
return 0;
}