原来最长链用的都是两遍dfs求直径或用两个数组存最长链和次长链,今天学习了,如果只需要求出直径大小,别的什么都不用,那么只用一个数组进行dp更加方便。
void DP(int u)
{
if (vis[u]) f1[u]=1,f2[u]=0,f3[u]=0;
else f1[u]=0,f2[u]=1e18,f3[u]=-1e18;
for (register int i=head[u]; i; i=e[i].next)
{
DP(e[i].to);
sum+=(K-f1[e[i].to])*f1[e[i].to]*e[i].w; f1[u]+=f1[e[i].to];
minn=min(minn,f2[u]+f2[e[i].to]+e[i].w); f2[u]=min(f2[u],f2[e[i].to]+e[i].w);
maxn=max(maxn,f3[u]+f3[e[i].to]+e[i].w); f3[u]=max(f3[u],f3[e[i].to]+e[i].w);
}
}
对于最短链,和最长链同理即可。
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e6+5;
int n,u,v,q,k,K,sum,minn,maxn;
int tot,dfn[N],low[N],size[N],d[N],f[N],son[N],top[N];
int p[N<<1],sta[N],f1[N],f2[N],f3[N];
bool vis[N];
int cnt,head[N];
struct edge{int next,to,w;}e[N<<1];
inline void add(int u,int v,int w)
{
cnt++;
e[cnt].next=head[u];
e[cnt].to=v;
e[cnt].w=w;
head[u]=cnt;
}
void dfs(int u,int fa)
{
dfn[u]=++tot;
size[u]=1;
for (register int i=head[u]; i; i=e[i].next)
if (e[i].to!=fa)
{
d[e[i].to]=d[u]+1; f[e[i].to]=u;
dfs(e[i].to,u);
size[u]+=size[e[i].to];
if (size[e[i].to]>size[son[u]]) son[u]=e[i].to;
}
low[u]=tot;
}
void dfs2(int u,int TP)
{
top[u]=TP;
if (son[u]) dfs2(son[u],TP);
for (register int i=head[u]; i; i=e[i].next)
if (e[i].to!=son[u] && e[i].to!=f[u]) dfs2(e[i].to,e[i].to);
}
inline int lca(int u,int v)
{
while (top[u]!=top[v])
{
if (d[top[u]]<d[top[v]]) swap(u,v);
u=f[top[u]];
}
if (d[u]>d[v]) swap(u,v);
return u;
}
inline bool cmp(int a,int b){return dfn[a]<dfn[b];}
void DP(int u)
{
if (vis[u]) f1[u]=1,f2[u]=0,f3[u]=0;
else f1[u]=0,f2[u]=1e18,f3[u]=-1e18;
for (register int i=head[u]; i; i=e[i].next)
{
DP(e[i].to);
sum+=(K-f1[e[i].to])*f1[e[i].to]*e[i].w; f1[u]+=f1[e[i].to];
minn=min(minn,f2[u]+f2[e[i].to]+e[i].w); f2[u]=min(f2[u],f2[e[i].to]+e[i].w);
maxn=max(maxn,f3[u]+f3[e[i].to]+e[i].w); f3[u]=max(f3[u],f3[e[i].to]+e[i].w);
}
}
signed main(){
scanf("%lld",&n);
for (register int i=1; i<n; ++i) scanf("%lld%lld",&u,&v),add(u,v,0),add(v,u,0);
dfs(1,0); dfs2(1,1);
memset(head,0,sizeof(head));
scanf("%lld",&q);
while (q--)
{
cnt=0;
scanf("%lld",&k); K=k;
for (register int i=1; i<=k; ++i) scanf("%lld",&p[i]),vis[p[i]]=true;
sort(p+1,p+k+1,cmp);
for (register int i=k; i>1; --i) p[++k]=lca(p[i],p[i-1]);
sort(p+1,p+k+1,cmp);
k=unique(p+1,p+k+1)-p-1;
for (register int i=1,top=0; i<=k; ++i)
{
while (top && low[sta[top]]<dfn[p[i]]) top--;
add(sta[top],p[i],d[p[i]]-d[sta[top]]); sta[++top]=p[i];
}
sum=0; minn=1e18; maxn=-1e18;
DP(p[1]);
printf("%lld %lld %lld\n",sum,minn,maxn);
for (register int i=1; i<=k; ++i) head[p[i]]=0,vis[p[i]]=false;
}
return 0;
}