在树上从一个点走到另一个点,路径长度可以用LCA求出:dist(u,v)=deep[u]+deep[v]-2*deep[lca(u,v)]。每次走完之后,把两个端点到LCA的路径上的所有点用并查集合并到同一个集合里;之后若当前点和目标点属于同一个集合,说明目标点已经走过,直接跳过即可。
《倍增被链剖虐成狗系列》
#include<iostream>
#include<cstdio>
// Capacity of the per-node arrays; nodes are numbered from 1.
// A typed constant instead of a macro so the name obeys scope and type rules.
const int N=500005;
// Wide integer type used for the accumulated answer.
typedef long long ll;
// Kept because the file uses unqualified standard names such as cout and endl.
using namespace std;
// n: number of nodes; m: number of targets read by main();
// now: node currently stood on; cnt: number of adjacency entries stored so far.
int n,m,now,cnt;
// Sum of the lengths of all walks performed; printed at the end of main().
ll ans;
// Adjacency lists as index-linked chains: head[x] is the first entry of x
// (0 = empty), next[e] the entry after e, list[e] the endpoint stored in e.
// Every undirected edge is stored in both directions, hence the doubled size.
// NOTE: `next` and `list` share their names with std::next / std::list, so
// uses at namespace scope are safest when written as ::next and ::list.
int head[N],next[N<<1],list[N<<1];
// deep[x]: depth of x below the DFS root.
// fa[x][k]: ancestor of x that is 2^k levels up (0 when there is none).
int deep[N],fa[N][20];
// Union-find parent pointers; f[x]==x marks a representative.
int f[N];
// Reads the next decimal integer from standard input.
// Every non-digit character before the number is skipped; if any of the
// skipped characters is '-', the result is negated.
// Returns 0 when the input ends before a digit is found.
inline int read()
{
    int value=0,sign=1;
    // getchar() returns an int so that EOF is distinguishable from every
    // real character; storing it in a char would lose that distinction.
    int ch=getchar();
    while (ch!=EOF&&(ch<'0'||ch>'9'))
    {
        if (ch=='-') sign=-1;
        ch=getchar();
    }
    // EOF is outside '0'..'9', so this loop also stops at end of input.
    while (ch>='0'&&ch<='9')
    {
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return value*sign;
}
inline void insert(int x,int y)
{
next[++cnt]=head[x];
head[x]=cnt;
list[cnt]=y;
}
void dfs(int x)
{
for (int i=1;(1<<i)<=deep[x];i++)
fa[x][i]=fa[fa[x][i-1]][i-1];
for (int i=head[x];i;i=next[i])
if (list[i]!=fa[x][0])
{
fa[list[i]][0]=x;
deep[list[i]]=deep[x]+1;
dfs(list[i]);
}
}
// Lowest common ancestor of x and y, found by binary lifting over fa[][].
inline int lca(int x,int y)
{
    // Work with x as the deeper endpoint.
    if (deep[x]<deep[y]) return lca(y,x);

    // Lift x to the depth of y, one set bit of the difference at a time.
    int gap=deep[x]-deep[y];
    for (int k=0;(1<<k)<=gap;++k)
    {
        if ((gap>>k)&1) x=fa[x][k];
    }
    if (x==y) return x;

    // Lift both nodes as far as possible while their ancestors still differ;
    // they end up as two distinct children of the answer.
    for (int k=19;k>=0;--k)
    {
        if (fa[x][k]==fa[y][k]) continue;
        x=fa[x][k];
        y=fa[y][k];
    }
    return fa[x][0];
}
// Returns the representative of the set containing i and re-points every
// node passed on the way directly at that representative.
int find(int i)
{
    // First pass: locate the representative.
    int root=i;
    while (f[root]!=root) root=f[root];

    // Second pass: compress the path that was just followed.
    while (f[i]!=root)
    {
        int up=f[i];
        f[i]=root;
        i=up;
    }
    return root;
}
// Starting from x, repeatedly attaches the representative of the current
// node's set to h and continues from that representative's tree parent,
// until a node already in h's set is reached.
static void link_upwards(int x,int h)
{
    for (int r=find(x);r!=h;r=find(x))
    {
        f[r]=h;
        x=fa[r][0];
    }
}

// Reads n, m and the start node, then n-1 edges, then m targets.
// A target in the same set as the current node is skipped. Otherwise the
// tree distance to it is added to ans, both halves of the path are merged
// into the set of their LCA, and the target becomes the current node.
// Prints ans when all targets have been handled.
int main()
{
    n=read();
    m=read();
    now=read();

    for (int v=1;v<=n;++v) f[v]=v;

    for (int e=1;e<n;++e)
    {
        int a=read();
        int b=read();
        insert(a,b);
        insert(b,a);
    }

    dfs(1);

    for (int q=0;q<m;++q)
    {
        int target=read();
        int here=find(now);
        int there=find(target);
        if (here!=there)
        {
            int meet=lca(now,target);
            int h=find(meet);
            ans+=deep[now]+deep[target]-2*deep[meet];
            link_upwards(now,h);
            link_upwards(target,h);
            now=target;
        }
    }

    cout << ans << endl;
    return 0;
}