统计每一条边的贡献,假设$u$是$v$的父节点,$(u,v)$的贡献为:$v$下面大学个数$f[v]$与$2*k-f[v]$的较小值。
#pragma comment(linker, "/STACK:1024000000,1024000000") #include<cstdio> #include<cstring> #include<cmath> #include<algorithm> #include<vector> #include<map> #include<set> #include<queue> #include<stack> #include<iostream> using namespace std; typedef long long LL; const double pi=acos(-1.0),eps=1e-8; void File() { freopen("D:\\in.txt","r",stdin); freopen("D:\\out.txt","w",stdout); } template <class T> inline void read(T &x) { char c = getchar(); x = 0;while(!isdigit(c)) c = getchar(); while(isdigit(c)) { x = x * 10 + c - '0'; c = getchar(); } } const int maxn=200010; struct Edge { int u,v,nx; }e[2*maxn]; int h[maxn],sz; int n,m,f[maxn],dis[maxn]; LL sum; void dfs(int x,int fa) { for(int i=h[x];i!=-1;i=e[i].nx) { if(fa==e[i].v) continue; dfs(e[i].v,x); f[x]=f[x]+f[e[i].v]; } for(int i=h[x];i!=-1;i=e[i].nx) { if(fa==e[i].v) continue; sum=sum+min(f[e[i].v],2*m-f[e[i].v]); } } void add(int u,int v) { e[sz].u=u; e[sz].v=v; e[sz].nx=h[u]; h[u]=sz++; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=2*m;i++) { int x; scanf("%d",&x); f[x]=1; } memset(h,-1,sizeof h); for(int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); add(u,v); add(v,u); } dfs(1,0); printf("%lld\n",sum); return 0; }