Description
A国有n个城市,城市之间有一些双向道路相连,并且城市两两之间有唯一路径。现在有火车在城市a,需要经过m个城市。火车按照以下规则行驶:每次行驶到还没有经过的城市中在m个城市中最靠前的。现在小A想知道火车经过这m个城市后所经过的道路数量。
Input
第一行三个整数n、m、a,表示城市数量、需要经过的城市数量,火车开始时所在位置。
接下来n-1行,每行两个整数x和y,表示x和y之间有一条双向道路。
接下来一行m个整数,表示需要经过的城市。
接下来n-1行,每行两个整数x和y,表示x和y之间有一条双向道路。
接下来一行m个整数,表示需要经过的城市。
Output
一行一个整数,表示火车经过的道路数量。
Sample Input
5 4 2
1 2
2 3
3 4
4 5
4 3 1 5
Sample Output
9
Data Constraint
题解
- 其实就是按顺序的求lca,求两点之见的经过的道路数量
- 对于要走到的那个点,如果之前到过的,就不用走了
- 没到过的话,lca计算距离,暴力标记路径中的点,并查集维护点不被重复标记
- 对,还有一个,直接dfs建树会炸栈,就bfs罗
代码
1 #include <cstdio> 2 #include <iostream> 3 #include <cstring> 4 using namespace std; 5 struct edge {int to,from;}e[500010*2]; 6 int head[500010],cnt,n,m,a,dep[500010],state[500010],fa[500010]; 7 long long f[500010][21],ans,mi[21]; 8 void insert(int x,int y) { e[++cnt].to=y; e[cnt].from=head[x]; head[x]=cnt; } 9 int getfather(int x) { return fa[x]==x?x:fa[x]=getfather(fa[x]); } 10 void bfs() 11 { 12 int l=0,r=1; 13 state[1]=1,dep[1]=1; 14 while (l<r) 15 { 16 l++; 17 for (int i=head[state[l]];i;i=e[i].from) 18 if (!dep[e[i].to]) 19 { 20 r++; 21 dep[e[i].to]=dep[state[l]]+1; 22 state[r]=e[i].to; 23 f[e[i].to][0]=state[l]; 24 } 25 } 26 } 27 int getlca(int x,int y,long long &ans) 28 { 29 if (dep[x]<dep[y]) swap(x,y); 30 for (int i=20;i>=0;i--) 31 if (dep[f[x][i]]>=dep[y]) 32 ans+=mi[i],x=f[x][i]; 33 if (x==y) return x; 34 for (int i=20;i>=0;i--) 35 if (f[x][i]!=f[y][i]) 36 ans+=mi[i+1],x=f[x][i],y=f[y][i]; 37 ans+=2; 38 return f[x][0]; 39 } 40 int main() 41 { 42 freopen("train.in","r",stdin); 43 freopen("train.out","w",stdout); 44 scanf("%d%d%d",&n,&m,&a); 45 for (int i=1;i<=n-1;i++) 46 { 47 int u,v; 48 scanf("%d%d",&u,&v); 49 insert(u,v),insert(v,u); 50 } 51 bfs(); 52 for (int i=1;i<=n;i++) fa[i]=i; 53 mi[0]=1; for (int i=1;i<=20;i++) mi[i]=mi[i-1]*2; 54 for (int i=1;i<=20;i++) 55 for (int j=1;j<=n;j++) 56 f[j][i]=f[f[j][i-1]][i-1]; 57 for (int i=1;i<=m;i++) 58 { 59 int x; 60 scanf("%d",&x); 61 if (getfather(x)!=x) continue; 62 int lca=getlca(a,x,ans); 63 int u=getfather(a),v=getfather(x); 64 while (u!=v) 65 { 66 if (dep[u]<dep[v]) swap(u,v); 67 u=fa[u]=getfather(f[u][0]); 68 } 69 if (getfather(lca)==lca) fa[lca]=f[lca][0]; 70 a=x; 71 } 72 printf("%lld",ans); 73 return 0; 74 }