关键还是在于K个点两两之间的lca最多只有k-1个。
因此我们可以建出虚树,然后一遍dp求出所有答案。
复杂度
O(K)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
#define ll long long
#define inf 0x3f3f3f3f
#define N 1000010
inline char gc(){
static char buf[1<<16],*S,*T;
if(T==S){T=(S=buf)+fread(buf,1,1<<16,stdin);if(S==T) return EOF;}
return *S++;
}
inline int read(){
int x=0,f=1;char ch=gc();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=gc();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=gc();
return x*f;
}
int n,h[N],num=0,dfn[N],dfnum=0,Mn[N<<1][20],Log[N<<1],dep[N],pos[N],tot=0;
int a[N],qq[N],b[N],sz[N],mn[N],mx[N],c[N],ansmn,ansmx;ll ans=0;
struct edge{
int to,next;
}data[N<<1];
inline void add(int x,int y){
data[++num].to=y;data[num].next=h[x];h[x]=num;
}
void dfs(int x,int Fa){
dfn[x]=++dfnum;Mn[++tot][0]=x;pos[x]=tot;
for(int i=h[x];i;i=data[i].next){
int y=data[i].to;if(y==Fa) continue;
dep[y]=dep[x]+1;dfs(y,x);Mn[++tot][0]=x;
}
}
inline int Min(int x,int y){return dep[x]<dep[y]?x:y;}
inline void initrmq(){
for(int i=1;i<=Log[n*2-1];++i)
for(int j=1;j<=n*2;++j){
if(j+(1<<i-1)>2*n-1) break;
Mn[j][i]=Min(Mn[j][i-1],Mn[j+(1<<i-1)][i-1]);
}
}
inline int lca(int x,int y){
x=pos[x];y=pos[y];if(x>y) swap(x,y);int t=Log[y-x+1];
return Min(Mn[x][t],Mn[y-(1<<t)+1][t]);
}
void dfs1(int x){
if(sz[x]) mn[x]=mx[x]=dep[x];
else mn[x]=inf,mx[x]=-inf;c[++tot]=x;
for(int i=h[x];i;i=data[i].next){
int y=data[i].to;dfs1(y);
ans-=(ll)sz[x]*sz[y]*dep[x]*2;sz[x]+=sz[y];
ansmn=min(ansmn,mn[x]+mn[y]-2*dep[x]);
mn[x]=min(mn[x],mn[y]);
ansmx=max(ansmx,mx[x]+mx[y]-2*dep[x]);
mx[x]=max(mx[x],mx[y]);
}
}
inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
inline void solve(){
int m=read(),top=0;num=0;ans=0;tot=0;ansmn=inf;ansmx=0;
for(int i=1;i<=m;++i) a[i]=read(),ans+=(ll)dep[a[i]]*(m-1),sz[a[i]]=1;
sort(a+1,a+m+1,cmp);qq[++top]=1;
for(int i=1;i<=m;++i){
int t=lca(a[i],qq[top]);
while(dep[qq[top]]>dep[t]){
int x=qq[top--];
if(dep[qq[top]]<dep[t]) qq[++top]=t;
add(qq[top],x);
}if(a[i]!=qq[top]) qq[++top]=a[i];
}int x=qq[top--];while(top) add(qq[top],x),x=qq[top--];
dfs1(1);printf("%lld %d %d\n",ans,ansmn,ansmx);
for(int i=1;i<=tot;++i) h[c[i]]=sz[c[i]]=0;
}
int main(){
// freopen("a.in","r",stdin);
n=read();Log[0]=-1;
for(int i=1;i<=n*2;++i) Log[i]=Log[i>>1]+1;
for(int i=1;i<n;++i){
int x=read(),y=read();add(x,y);add(y,x);
}dfs(1,0);initrmq();memset(h,0,sizeof(h));
int owo=read();while(owo--) solve();
return 0;
}