Description
给定一棵有
n
个节点并且以1为根的树,根的深度为1。给定
Solution
首先我们要知道一个重要的性质:对于一个连续块,它最终的
LCA
一定可以由这个连续块内的相邻两个元素求
LCA
得出。这里简单证明一下:我们不妨设一个连续块的
LCA
为
A
(这里假定
我们定义
dp[i][j]
表示遍历到第
i
个点时已经形成了
- 留着
i
这个点,让它和后面的点形成一个连续块,即
dp[i][j]=min(dp[i][j],dp[i−1][j]) ; -
i
这个点独立形成一个连续块,即
dp[i][j]=min(dp[i][j],dp[i−1][j−1]+dep[id[i]]) ; - 让
i
和
i−1 的 LCA 成为他们所在的连续块的 LCA (虽然这两个点的 LCA 不一定是他们所在的连续块的 LCA ,但是根据上面求得的性质我们知道这个连续块的 LCA 一定会被更新到),即 dp[i][j]=min(dp[i][j],dp[i−2][j−1]+dep[LCA(id[i−1],id[i])]) ;
接下来可以考虑 dp 的优化了。时间上,我们发现一对相邻的点的 LCA 在转移时可能被用到多次,我们可以先预处理出所有相邻点的 LCA ,再直接拿过来用。空间上,可以考虑滚动,但还有个方法更优,我们考虑到 n∗k≤3∗105 ,那么我们只要开个一维数组即可。
Code
#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<iostream>
#define M 300005
using namespace std;
template <class T>
inline void Rd(T &res){
char c;res=0;int k=1;
while(c=getchar(),c<48&&c!='-');
if(c=='-'){k=-1;c='0';}
do{
res=(res<<3)+(res<<1)+(c^48);
}while(c=getchar(),c>=48);
res*=k;
}
template <class T>
inline void Pt(T res){
if(res<0){
putchar('-');
res=-res;
}
if(res>=10)Pt(res/10);
putchar(res%10+48);
}
void check(int &a,int b){
if(a==-1||a>b)a=b;
}
struct edge{
int v,nxt;
}e[M<<1];
int n,k,A[M];
int head[M],edgecnt;
int dep[M],fa[M],top[M],L[M],R[M],sz[M],son[M],tim;
void add_edge(int u,int v){
e[++edgecnt].v=v;e[edgecnt].nxt=head[u];head[u]=edgecnt;
}
void dfs(int x,int t){
if(~t)dep[x]=dep[t]+1;
fa[x]=t;
L[x]=++tim;
sz[x]=1;
for(int i=head[x];~i;i=e[i].nxt){
int v=e[i].v;
if(v==t)continue;
dfs(v,x);
sz[x]+=sz[v];
if(sz[v]>sz[son[x]])son[x]=v;
}
R[x]=tim;
}
void rdfs(int x,int t,int tp){
top[x]=tp;
for(int i=head[x];~i;i=e[i].nxt){
int v=e[i].v;
if(v==t)continue;
if(v==son[x])rdfs(v,x,tp);
else rdfs(v,x,v);
}
}
int LCA(int u,int v){
while(top[u]!=top[v]){
if(dep[top[u]]>dep[top[v]])u=fa[top[u]];
else v=fa[top[v]];
}
return dep[u]<dep[v]?u:v;
}
int dp[M+60000];
int lca[M];
void solve(){
memset(dp,-1,sizeof(dp));
dp[0]=0;
for(int i=1;i<n;i++)
lca[i]=LCA(A[i],A[i+1]);
for(register int i=1;i<=n;++i){
int pre2=(i-2)*(k+1);
int pre1=(i-1)*(k+1);
int now=i*(k+1);
for(register int j=0;j<=k;++j){
check(dp[now+j],dp[pre1+j]);
if(j>=1){
if(~dp[pre1+j-1])check(dp[now+j],dp[pre1+j-1]+dep[A[i]]);
if(i>=2&&~dp[pre2+j-1])check(dp[now+j],dp[pre2+j-1]+dep[lca[i-1]]);
}
}
}
}
int main(){
int a,b;
while(scanf("%d",&n)!=EOF){
memset(head,-1,sizeof(head));
memset(son,0,sizeof(son));
edgecnt=0;tim=0;
dep[1]=1;
Rd(k);
for(int i=1;i<=n;++i)Rd(A[i]);
for(int i=1;i<n;++i){
Rd(a);Rd(b);
add_edge(a,b);
add_edge(b,a);
}
dfs(1,-1);
rdfs(1,-1,1);
solve();
Pt(dp[n*(k+1)+k]);
putchar('\n');
}
return 0;
}