据说这道题有ON做法
好像是真的
但是没人写
我也不会
所以我就写了正常的
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)做法啦
显然要二分一下对吧
然后我们怎么判断呢?
考虑树形DP:
f
[
u
]
f[u]
f[u]表示以
u
u
u为根节点的子树中,最远的没有被覆盖到的关键节点的距离
g
[
u
]
g[u]
g[u]表示以
u
u
u为根节点的自述中,最近的一个选择的点的距离
那么转移显然就是
f
[
u
]
=
max
{
f
[
v
]
+
1
}
f[u]=\max\{f[v]+1\}
f[u]=max{f[v]+1}
g
[
u
]
=
min
{
g
[
v
]
+
1
}
g[u]=\min\{g[v]+1\}
g[u]=min{g[v]+1}
但是还有别的情况啊,要不然答案就是0了啊
情况1
当
f
[
u
]
+
g
[
u
]
≤
δ
(
二
分
的
值
)
f[u]+g[u]\leq\delta(二分的值)
f[u]+g[u]≤δ(二分的值)时,让
u
u
u子树中所有未匹配的节点连向
g
[
u
]
g[u]
g[u]那个点就可以,
f
[
u
]
f[u]
f[u]清零
情况2
当
f
[
u
]
=
δ
f[u]=\delta
f[u]=δ时,需要选择当前的点,
f
[
u
]
f[u]
f[u]清零,
g
[
u
]
=
0
g[u]=0
g[u]=0
情况3
当我们做到1的时候,如果还有没覆盖的,就要把1选上
那么为了便于处理,当
f
[
u
]
f[u]
f[u]清零的时候,我们让
f
[
u
]
=
−
1
f[u]=-1
f[u]=−1,当
g
[
u
]
g[u]
g[u]清零的时候(当下面还没有选过的时候),我们让
g
[
u
]
=
i
n
f
g[u]=inf
g[u]=inf
然后还有一些细节,比如说二分的判断之类的
我不会告诉你我在这上面调了一个小时因为我对拍的暴力写挂了
#include <bits/stdc++.h>
using namespace std;
# define Rep(i,a,b) for(int i=a;i<=b;i++)
# define _Rep(i,a,b) for(int i=a;i>=b;i--)
# define RepG(i,u) for(int i=head[u];~i;i=e[i].next)
typedef long long ll;
const int N=3e5+5;
template<typename T> void read(T &x){
x=0;int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
for(;isdigit(c);c=getchar())x=(x<<1)+(x<<3)+c-'0';
x*=f;
}
int n,m,ans,tot;
int head[N],cnt;
int f[N],g[N];
int col[N];
struct Edge{
int to,next;
}e[N<<1];
void add(int x,int y){
e[++cnt]=(Edge){y,head[x]},head[x]=cnt;
}
void dfs(int u,int fa,int delta){
if(!col[u])f[u]=-1;
else f[u]=0;
g[u]=2e9;
RepG(i,u){
int v=e[i].to;
if(v==fa)continue;
dfs(v,u,delta);
if(f[v]!=-1)f[u]=max(f[u],f[v]+1);
if(g[v]!=2e9)g[u]=min(g[u],g[v]+1);
}
if(f[u]!=-1){
if(g[u]==2e9){
if(f[u]==delta){
f[u]=-1;
tot++;
g[u]=0;
}
}
else{
if(g[u]+f[u]<=delta)f[u]=-1;
else if(f[u]==delta){
f[u]=-1;
tot++;
g[u]=0;
}
}
}
}
bool check(int delta){
tot=0;
dfs(1,0,delta);
if(f[1]!=-1)tot++;
return tot<=m;
}
int main()
{
memset(head,-1,sizeof(head));
read(n),read(m);
Rep(i,1,n)read(col[i]);
Rep(i,1,n-1){
int x,y;
read(x),read(y);
add(x,y),add(y,x);
}
int l=0,r=n;
while(l<=r){
int mid=l+r>>1;
if(check(mid))ans=mid,r=mid-1;
else l=mid+1;
}
// check(1);
printf("%d\n",ans);
return 0;
}