题意
给定一棵 n 个点的树和 q 组询问,每次询问选定其中 k 个点,求这 k 个点两两之间的距离之和、距离最小值和距离最大值。数据范围:\(n\leqslant 10^6,\ q\leqslant 50000,\ \sum k_i \leqslant 2n\)。
题解
由于 \(\sum k_i \leqslant 2n\),每次询问的复杂度只能与 \(k_i\) 相关而不能与 n 相关,因此需要对每次询问的点建虚树。
建出虚树后,标记所有询问点。对于距离之和,只需对虚树上的每条边(长度为 w)求出其下方子树中的标记点个数 s,这条边的贡献就是 \(s\cdot(k-s)\cdot w\)。对于距离最小值,对每个点 x 用 dfs 求出其子树内(包括 x 本身)距 x 最近和次近的标记点的距离,这两个标记点需来自 x 的不同儿子子树(或其中一个就是 x 本身);把这两个值相加,再对所有 x 取 min 即可。距离最大值同理,改为求最远和次远的距离并取 max。
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<vector>
// Sentinel "infinite" distance (0x7f7f7f, about 8.3e6). Any tree distance is < maxn,
// so even the sum of two real distances stays well below it.
const int inf=0x7f7f7f;
typedef long long ll;
using namespace std;
const int maxn=1e6;                            // maximum number of vertices
int n,tot,Time,k,m;                            // vertices, stored edges, Euler-tour length, queries, current query size
int pre[maxn*2+8],now[maxn+8],son[maxn*2+8];   // forward-star adjacency of the original tree
int dep[maxn+8],dfn[maxn+8];                   // depth (root has depth 1) and first position in the Euler tour
// Sparse table over the Euler tour: f[i][0] is the i-th tour vertex and f[i][j] the
// shallowest vertex among positions [i, i+2^j). The tour has at most 2n-1 < 2^21 entries,
// so 21 levels are enough.
int f[maxn*2+8][21];
int a[maxn+8],st[maxn+8];                      // current query vertices; stack for virtual-tree construction
ll ans1;                                       // sum of pairwise distances
int ans2,ans3;                                 // minimum / maximum pairwise distance
// Store directed edge u->v at the head of u's forward-star list in the original tree.
void add(int u,int v)
{
	son[++tot]=v;        // new edge slot; its target is v
	pre[tot]=now[u];     // link to u's previous first edge
	now[u]=tot;          // the new edge becomes u's first edge
}
// Read one (optionally negative) integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes the result negative.
// Returns 0 if end of input is reached before any digit. The original looped forever in that
// case, because EOF stored in a char never satisfies the digit test.
int read()
{
	int x=0,sign=1;
	int ch=getchar();  // int, not char, so EOF stays distinguishable from byte values
	for (;ch!=EOF&&(ch<'0'||ch>'9');ch=getchar()) if (ch=='-') sign=-1;
	for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
	return x*sign;
}
// Depth-first traversal of the original tree from x (whose parent is fa; 0 means none).
// Sets dep[] and dfn[] and writes the Euler tour into f[1..Time][0]: a vertex is appended
// when first entered and again after each of its children finishes.
// Uses an explicit stack instead of recursion. A path-shaped tree with n = 1e6 would
// otherwise recurse about 1e6 frames deep and overflow a typical default stack.
// The visiting order, and therefore the tour, is the same as the recursive version.
void Build_Tree(int x,int fa)
{
	// Parallel stacks: vertex, its parent, and the next adjacency edge still to examine.
	std::vector<int> stk(1,x),par(1,fa),nxt(1,now[x]);
	dep[x]=dep[fa]+1;
	f[dfn[x]=++Time][0]=x;
	while (!stk.empty())
	{
		int u=stk.back(),p=nxt.back();
		while (p&&son[p]==par.back()) p=pre[p];  // never walk back to the parent
		if (!p)
		{
			// u is finished: pop it and re-append its parent to the tour.
			stk.pop_back();par.pop_back();nxt.pop_back();
			if (!stk.empty()) f[++Time][0]=stk.back();
			continue;
		}
		nxt.back()=pre[p];  // resume after this edge when control returns to u
		int child=son[p];
		dep[child]=dep[u]+1;
		f[dfn[child]=++Time][0]=child;
		stk.push_back(child);par.push_back(u);nxt.push_back(now[child]);
	}
}
// Return whichever of the two vertices is shallower; on equal depth, y is returned.
int RMQ(int x,int y)
{
	if (dep[y]<=dep[x]) return y;
	return x;
}
// Fill the sparse table over the Euler tour f[1..Time][0].
// f[i][j] becomes the shallowest vertex among tour positions [i, i+2^j).
// A level is built whenever its window 2^j fits in the tour. The exact integer test
// replaces the floating-point log(Time)/log(2) bound, which was recomputed on every
// iteration and is subject to rounding.
void Build_St()
{
	for (int j=1;(1<<j)<=Time;j++)
		for (int i=1;i+(1<<j)-1<=Time;i++)
			f[i][j]=RMQ(f[i][j-1],f[i+(1<<(j-1))][j-1]);
}
int Get_Lca(int x,int y)
{
if (dfn[x]>dfn[y]) swap(x,y);
int len=dfn[y]-dfn[x]+1,t=log(len)/log(2);
return RMQ(f[dfn[x]][t],f[dfn[y]-(1<<t)+1][t]);
}
// A pair of values maintained by Update_Big / Update_Small:
// the leading value and the runner-up.
struct node
{
	int x;  // leading (best) value
	int y;  // runner-up value
};
// Merge candidate b into pair a, keeping the larger leader in .x.
// .y becomes the larger of the leader's own .y and the other pair's .x.
node Update_Big(node a,node b)
{
	node top=a,other=b;
	if (b.x>a.x) top=b,other=a;
	node res;
	res.x=top.x;
	res.y=(top.y<other.x)?other.x:top.y;
	return res;
}
// Merge candidate b into pair a, keeping the smaller leader in .x.
// .y becomes the smaller of the leader's own .y and the other pair's .x.
node Update_Small(node a,node b)
{
	node low=a,high=b;
	if (b.x<a.x) low=b,high=a;
	node res;
	res.x=low.x;
	res.y=(high.x<low.y)?high.x:low.y;
	return res;
}
struct Virtual_Tree
{
int rt,tot;
int st[maxn+8],tail;
int pre[maxn*2+8],now[maxn+8],son[maxn*2+8],val[maxn*2+8];
node f[maxn+8],g[maxn+8];
bool color[maxn+8];
int siz[maxn+8];
void clear()
{
tot=0;
while(tail) now[st[tail--]]=0;
}
void add(int u,int v)
{
//printf("line:%d %d\n",u,v);
if (!now[u]) st[++tail]=u;
pre[++tot]=now[u];
now[u]=tot;
son[tot]=v;
val[tot]=abs(dep[u]-dep[v]);
}
void dfs(int x,int fa)
{
f[x]=(node){color[x]?0:inf,inf};
g[x]=(node){color[x]?0:-inf,-inf};
siz[x]=color[x];
for (int p=now[x];p;p=pre[p])
{
int child=son[p];
if (child==fa) continue;
dfs(child,x);
siz[x]+=siz[child];
ans1+=1ll*siz[child]*(m-siz[child])*val[p];
f[x]=Update_Small(f[x],(node){f[child].x+val[p],inf});
g[x]=Update_Big(g[x],(node){g[child].x+val[p],-inf});
}
//printf("%d %d %d\n",x,g[x].x,g[x].y);
ans2=min(ans2,f[x].x+f[x].y);
ans3=max(ans3,g[x].x+g[x].y);
}
}VT;
// Sort comparator: earlier DFS entry time (dfn) first.
bool cmp(int x,int y)
{
	return dfn[y]>dfn[x];
}
// Answer one query: read m vertices, mark them, build their virtual tree with the
// dfn-ordered stack construction, run VT.dfs, and print
// "<sum of pairwise distances> <min pairwise distance> <max pairwise distance>".
// NOTE(review): with m == 1 there is no pair; ans2/ans3 then print inf / -inf.
void solve()
{
//puts("Enter");
// Forget the previous query's virtual-tree edges.
VT.clear();
m=read();
for (int i=1;i<=m;i++) VT.color[a[i]=read()]=1;
// Process query vertices in DFS (dfn) order.
sort(a+1,a+m+1,cmp);
//for (int i=1;i<=m;i++) printf("%d ",a[i]);puts("");
// st[1..tail] holds a root-to-leaf chain of the virtual tree built so far.
// st[0] is never written (stays 0) and dep[0] == 0, so it acts as a sentinel below every real vertex.
int tail=0;
for (int i=1;i<=m;i++)
{
//printf("Time:%d\n",i);
if (!tail) {st[++tail]=a[i];continue;}
int Lca=Get_Lca(st[tail],a[i]),lst=0;
//printf("Lca:%d\n",Lca);
// Pop every chain vertex deeper than Lca, linking each popped vertex to the one
// popped after it (its parent on the chain).
while(dep[Lca]<dep[st[tail]])
{
if (lst) VT.add(lst,st[tail]),VT.add(st[tail],lst);
lst=st[tail];tail--;
}
//printf("lst:%d %d\n",lst,Lca);
// The topmost popped vertex hangs directly below Lca.
if (lst) VT.add(lst,Lca),VT.add(Lca,lst);
// Equal depth here means st[tail] is Lca itself (both are ancestors of the previous top);
// otherwise Lca is a new branching vertex and joins the chain.
if (dep[Lca]!=dep[st[tail]]) st[++tail]=Lca;
st[++tail]=a[i];
}
// Link what remains on the chain.
while(tail!=1) tail--,VT.add(st[tail],st[tail+1]),VT.add(st[tail+1],st[tail]);
// Edges are stored in both directions, so dfs may start at any vertex. a[1] is a query
// vertex, not necessarily the top of the virtual tree; every quantity dfs computes
// (per-edge pair counts, best two distances from different branches) is rooting-independent.
VT.rt=a[1];
//puts("Begin_Dfs");
ans1=0,ans2=inf,ans3=-inf;
VT.dfs(VT.rt,0);
// Unmark the query vertices for the next query.
for (int i=1;i<=m;i++) VT.color[a[i]]=0;
printf("%lld %d %d\n",ans1,ans2,ans3);
}
// Read the tree, prepare LCA structures, then answer k queries.
int main()
{
	n=read();
	// n-1 undirected edges, each stored in both directions.
	for (int e=n-1;e>0;e--)
	{
		int u=read();
		int v=read();
		add(u,v);
		add(v,u);
	}
	k=read();          // number of queries
	Build_Tree(1,0);   // Euler tour rooted at vertex 1
	Build_St();        // sparse table over the tour, used by Get_Lca
	for (int q=0;q<k;q++) solve();
	return 0;
}