题目:
题解:
看到所有k的和<=2*n各位应该明白我们应该用虚树了,这题的虚树也很好找嘛,不就是把不用的链缩起来弄成长度和呗
这个和我也会做,深度从深到浅考虑虚树中的点,每个点往上跳一条边的时候,这条边的权值将被算入除ta子树外的所有节点中,记录一下就可以了
这个最大最小怎么办呢?我们可以再维护几个数组,Max[i]表示子树关键点到i的最长链,maxx[i]表示次长链;Min[i]表示最短链,minn[i]表示次短链。然后拼一下就好了【树的直径阴影】
看清楚啊,人家让先输出小的再输出大的。。。。
代码:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define INF 1e9
#define LL long long
using namespace std;
const int N=1000005;
const int sz=19;
int tot,nxt[N*2],point[N],v[N*2],f[N][sz],k,m,n,size[N],mi[sz],nn,in[N],out[N],h[N],dis[N],gj[N],ask[N],flag[N],top,stack[N];
LL he[N],Max[N],maxx[N],Min[N],minn[N],ansma,ansmn,c[N*2];
void addline(int x,int y,LL w){++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=w;}
void dfs(int x,int fa)
{
in[x]=++nn;h[x]=h[fa]+1;
for (int i=1;i<sz;i++)
if (h[x]<mi[i]) break;
else f[x][i]=f[f[x][i-1]][i-1];
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa) dis[v[i]]=dis[x]+1,f[v[i]][0]=x,dfs(v[i],x);
out[x]=nn;
}
int lca(int x,int y)
{
if (h[x]<h[y]) swap(x,y);
int k=h[x]-h[y];
for (int i=0;i<sz;i++)
if (k&(1<<i)) x=f[x][i];
if (x==y) return x;
for (int i=sz-1;i>=0;i--)
if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
void treedp(int x)
{
he[x]=0;size[x]=0; Max[x]=maxx[x]=0; Min[x]=minn[x]=INF;
if (gj[x]==m) size[x]++,Min[x]=0; int ez=0;
for (int i=point[x];i;i=nxt[i])
{
ez++;
treedp(v[i]);
if (Max[v[i]]+c[i]>Max[x])
{
maxx[x]=Max[x];
Max[x]=Max[v[i]]+c[i];
}else maxx[x]=max(maxx[x],Max[v[i]]+c[i]);//大
if (Min[v[i]]+c[i]<Min[x])
{
minn[x]=Min[x];
Min[x]=Min[v[i]]+c[i];
}else minn[x]=min(minn[x],Min[v[i]]+c[i]);//小
size[x]+=size[v[i]];
he[x]+=he[v[i]]+c[i]*(LL)(k-size[v[i]])*(LL)size[v[i]]; //和
}
if (gj[x]==m || ez>1)
{
ansma=max(ansma,Max[x]+maxx[x]);
ansmn=min(ansmn,Min[x]+minn[x]);
}
point[x]=0;
}
int cmp(int a,int b){return in[a]<in[b];}
void work()
{
scanf("%d",&k);
for (int i=1;i<=k;i++) scanf("%d",&ask[i]),gj[ask[i]]=flag[ask[i]]=m;
sort(ask+1,ask+k+1,cmp); ask[0]=k;
for (int i=2;i<=k;i++)
{
int t=lca(ask[i],ask[i-1]);
if (flag[t]!=m) flag[t]=m,ask[++ask[0]]=t;
}
if (flag[1]!=m) flag[1]=m,ask[++ask[0]]=1;
sort(ask+1,ask+ask[0]+1,cmp);
tot=0; stack[top=1]=1;
for (int i=2;i<=ask[0];i++)
{
while (in[ask[i]]<in[stack[top]] || in[ask[i]]>out[stack[top]]) top--;
addline(stack[top],ask[i],dis[stack[top]]+dis[ask[i]]-2*dis[lca(stack[top],ask[i])]);
stack[++top]=ask[i];
}
ansma=-INF; ansmn=INF;
treedp(1);
printf("%lld %lld %lld\n",he[1],ansmn,ansma);
}
int main()
{
mi[0]=1;
for (int i=1;i<sz;i++) mi[i]=mi[i-1]*2;
scanf("%d",&n);
for (int i=1;i<n;i++)
{
int x,y;scanf("%d%d",&x,&y);
addline(x,y,1); addline(y,x,1);
}
dfs(1,0);memset(point,0,sizeof(point));
scanf("%d",&m);
while (m) work(),m--;
}