第一眼就是树形DP,然而看到数据范围以后望而却步…… O(n×m) 的时间复杂度实在受不了啊……
观察数据范围,发现题目给出的是 ∑ki≤5×105 ,那我们就要考虑减少每次询问的时间复杂度,不能是 O(n) 的,应该和 ki 有关吧。
于是我们就想:有没有什么高级的数据结构,可以让我们的时间复杂度降下来,把多余的状态舍去呢?
显然虚树满足了我们的要求。那么到底什么是虚树呢?虚树又是怎么构建的呢?
所谓虚树,其实就是把询问中需要用到的点建到另一棵树上,把一部分无效信息删掉,把一部分信息合并,从而提高查询的效率。
对于一棵树上的两点,显然只有唯一路径,那么我们是不是可以把这一条路径上的所有信息合并,变成新的树上连接这两个点的一条边呢?这就是虚树的核心。
然后考虑怎么建树:
- 首先树根显然在虚树中;
- 其次,如果有两个关键点 x,y 在虚树中, lca(x,y) 也在虚树中。因为如果 lca(x,y) 不在虚树中,那么 1→x 的路径和 1→y 的路径会重复计算 1→lca(x,y) 的信息。
- 考虑用一个栈来维护虚树中的节点,我们分几种情况来讨论:(设当前栈顶的节点为
top
,栈中第二个节点为
pre
,当前要加入的节点为
x
,
lca(top,x) 为 t ,d[] 为各个节点的深度)
- top=t ,那么直接把 x 压入栈中即可。
d[top]>d[t] ,那么显然 t 需要入栈,这里又需要分成两种情况讨论:
d[pre]>d[t] ,将 top 和 pre 连边,重复上一重判断。-
d[pre]≤d[t]
,将
top
和
t
连边,将
t 和 x 依次压入栈中。
这段写的可能有些乱,希望各位大佬画图理解一下吧。
于是我们用虚树把树形DP的时间复杂度降到了
附上AC代码:
#include <cstdio>
#include <cctype>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=25e4+10;
struct node{
int h[N],num;
struct side{
int to,w,nt;
}s[N<<1];
inline void add(int x,int y,int w){
if (x==y) return;
s[++num]=(side){y,w,h[x]},h[x]=num;
}
}s1,s2;
int n,wz[N],size,d[N],f[N][20],m,a[N],sk[N];
ll w[N];
bool b[N];
inline char nc(void){
static char ch[100010],*p1=ch,*p2=ch;
return p1==p2&&(p2=(p1=ch)+fread(ch,1,100010,stdin),p1==p2)?EOF:*p1++;
}
inline void read(int &a){
static char c=nc();int f=1;
for (;!isdigit(c);c=nc()) if (c=='-') f=-1;
for (a=0;isdigit(c);a=(a<<3)+(a<<1)+c-'0',c=nc());
return (void)(a*=f);
}
inline void so(int x){
wz[x]=++size;
for (int i=1; (1<<i)<=d[x]; ++i) f[x][i]=f[f[x][i-1]][i-1];
for (int i=s1.h[x]; i; i=s1.s[i].nt)
if (s1.s[i].to!=f[x][0]){
d[s1.s[i].to]=d[f[s1.s[i].to][0]=x]+1;
w[s1.s[i].to]=min(w[x],1ll*s1.s[i].w);
so(s1.s[i].to);
}
return;
}
inline bool cmp(int a,int b){return wz[a]<wz[b];}
inline int lca(int x,int y){
if (d[x]<d[y]) swap(x,y);
for (int i=17; i>=0; --i) if (d[f[x][i]]>=d[y]) x=f[x][i];
if (x==y) return x;
for (int i=17; i>=0; --i) if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
inline ll dp(int x){
if (b[x]) return w[x];ll ret=0;
for (int i=s2.h[x]; i; i=s2.s[i].nt) ret+=dp(s2.s[i].to);
return min(w[x],ret);
}
inline void clear(int x){
for (int i=s2.h[x]; i; i=s2.s[i].nt) clear(s2.s[i].to);
return (void)(s2.h[x]=0,b[x]=0);
}
int main(void){
read(n);
for (int i=1,x,y,v; i<n; ++i) read(x),read(y),read(v),s1.add(x,y,v),s1.add(y,x,v);
w[1]=1ll<<60,d[1]=1,so(1),read(m);
while (m--){
int len,top;read(len),s2.num=0;
for (int i=1; i<=len; ++i) read(a[i]),b[a[i]]=1;
sort(a+1,a+1+len,cmp);
sk[top=1]=1;
for (int i=1; i<=len; ++i){
int tmp=lca(a[i],sk[top]);
while (d[sk[top]]>d[tmp])
if (d[sk[top-1]]<d[tmp]) s2.add(tmp,sk[top],0),sk[top]=tmp;
else s2.add(sk[top-1],sk[top],0),--top;
if (sk[top]!=a[i]) sk[++top]=a[i];
}
while (top>1) s2.add(sk[top-1],sk[top],0),--top;
printf("%lld\n",dp(1)),clear(1);
}
return 0;
}