// Straightforward ("no-brainer") DP; no detailed write-up for this one.
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<vector>
#define maxn 1000010
using namespace std;
// Adjacency lists of the input tree stored as linked edge records:
// head[x] is the index of the most recently added edge leaving x (0 = none),
// to[e] is the endpoint of edge e and next[e] the following edge in the list.
// NOTE(review): with `using namespace std`, the names `next` and `size` can
// collide with std::next (C++11) / std::size (C++17) — confirm the target
// language standard or qualify uses as ::next / ::size.
int head[maxn],to[2*maxn],next[2*maxn];
// Per-query auxiliary tree filled by add(): v[x] lists x's children and
// g[x] the matching edge lengths (depth differences). Cleared again by DP().
vector<int> v[maxn],g[maxn];
// size[x]: number of flagged vertices in x's auxiliary subtree (set by DP);
// dep[x]: depth of x in the input tree as assigned by dfs();
// fa[j][x]: ancestor of x reached by 2^j parent steps (0 when there is none).
int size[maxn],dep[maxn],fa[20][maxn];
// a[1..k]: vertices of the current query; in[x]: DFS preorder index of x;
// st[]: stack used by main() while building the auxiliary tree.
int a[maxn],in[maxn],st[maxn];
// n: vertex count; num: edge-record counter; tot: preorder counter;
// T: number of queries; top: stack pointer for st[]; k: size of the current
// query. m is not used anywhere in the visible code.
int n,m,num,tot,T,top,k;
// ans: accumulated sum of distances for the current query;
// sum[x]: total distance from the flagged vertices in x's subtree up to x.
long long ans,sum[maxn];
// flag[x]: true while x is one of the current query's vertices.
bool flag[maxn];
// mx[x] / mn[x]: largest / smallest distance from x down to a flagged vertex
// in its subtree; MX / MN: running maximum / minimum printed for the query.
int mx[maxn],mn[maxn],MX,MN;
void addedge(int x,int y)
{
num++;to[num]=y;next[num]=head[x];head[x]=num;
}
// Records y as a child of x in the per-query auxiliary tree, together with
// the length of the connecting edge (difference of the two depths).
void add(int x,int y)
{
    const int length=dep[y]-dep[x];
    v[x].push_back(y);
    g[x].push_back(length);
}
void dfs(int x)
{
in[x]=++tot;
for (int p=head[x];p;p=next[p])
if (to[p]!=fa[0][x]) fa[0][to[p]]=x,dep[to[p]]=dep[x]+1,dfs(to[p]);
}
// Ordering predicate for sort(): x precedes y when x has the smaller DFS
// preorder index.
bool cmp(int x,int y)
{
    const int order_x=in[x];
    const int order_y=in[y];
    return order_x<order_y;
}
void DP(int x)
{
if (flag[x]) size[x]=1,mx[x]=0,mn[x]=0;
else size[x]=0,mx[x]=-n,mn[x]=n;
sum[x]=0;
for (int i=0;i<v[x].size();i++) DP(v[x][i]),size[x]+=size[v[x][i]];
for (int i=0;i<v[x].size();i++)
{
int y=v[x][i],len=g[x][i];
ans+=1ll*(sum[y]+1ll*size[y]*len)*(size[x]-size[y]);
MX=max(MX,mx[x]+mx[y]+len);MN=min(MN,mn[x]+mn[y]+len);
sum[x]+=(long long)sum[y]+1ll*size[y]*len;
mx[x]=max(mx[x],mx[y]+len);mn[x]=min(mn[x],mn[y]+len);
}
v[x].clear();g[x].clear();flag[x]=0;
}
// Returns the vertex reached from x after d parent steps, using the
// binary-lifting table fa[][]: one jump per set bit among bits 0..19 of d.
int go_up(int x,int d)
{
    int level=0;
    while (level<20)
    {
        const int mask=1<<level;
        if ((d&mask)!=0)
        {
            x=fa[level][x];
        }
        ++level;
    }
    return x;
}
// Lowest common ancestor of x and y via binary lifting.
int LCA(int x,int y)
{
    // Bring the deeper vertex up to the depth of the shallower one.
    const int gap=dep[x]-dep[y];
    if (gap>0)
    {
        x=go_up(x,gap);
    }
    else
    {
        y=go_up(y,-gap);
    }

    if (x==y)
    {
        return x;
    }

    // Lift both vertices together for as long as their ancestors differ;
    // afterwards they sit directly below the common ancestor.
    for (int level=19;level>=0;--level)
    {
        const int up_x=fa[level][x];
        const int up_y=fa[level][y];
        if (up_x==up_y) continue;
        x=up_x;
        y=up_y;
    }
    return fa[0][x];
}
int main()
{
scanf("%d",&n);
for (int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
addedge(x,y);addedge(y,x);
}
dfs(1);
for (int j=1;j<=19;j++)
for (int i=1;i<=n;i++)
fa[j][i]=fa[j-1][fa[j-1][i]];
scanf("%d",&T);
while (T--)
{
top=1;st[1]=1;
scanf("%d",&k);
for (int i=1;i<=k;i++) scanf("%d",&a[i]),flag[a[i]]=1;
sort(a+1,a+k+1,cmp);
for (int i=1;i<=k;i++)
{
int lca=LCA(st[top],a[i]);
while (dep[st[top]]>dep[lca])
if (dep[st[top-1]]<dep[lca]) add(lca,st[top]),st[top]=lca;
else add(st[top-1],st[top]),top--;
if (st[top]!=a[i]) st[++top]=a[i];
}
while (top>1) add(st[top-1],st[top]),top--;
ans=0;MX=-n;MN=n;
DP(1);
printf("%lld %d %d\n",ans,MN,MX);
}
return 0;
}