测试时已经想到是二分了 可是想了很长时间贪心最后都被自己否定了
正解是 dp 不难想的样子?
有二分的代码不好调......
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int MAXN = 2e5+10;
int tot,g[MAXN],num[MAXN*2],nnext[MAXN*2];
void add(int x,int y)
{
tot++;
nnext[tot]=g[x];
g[x]=tot;
num[tot]=y;
}
int f[MAXN];
int team[MAXN],head,tail;
int in[MAXN],tin[MAXN];
int fa[MAXN];
bool is[MAXN];
int n,k;
int init()
{
head=tail=0;
team[++tail]=1;
while(head<tail)
{
int x=team[++head];//cout<<x<<endl;
for(int i=g[x];i;i=nnext[i])
{
int tmp=num[i];
if(tmp==fa[x]) continue ;
fa[tmp]=x;
in[x]++;
team[++tail]=tmp;
}
}
f[0]=1e9;
}
bool ok(int k)
{
head=tail=0;
for(int i=1;i<=n;i++)
{
tin[i]=in[i];
if(tin[i]==0) team[++tail]=i;
f[i]=0;
}
while(head<tail)
{
int x=team[++head];
// cout<<x<<endl;
if(is[x]==true)
{
for(int i=g[x];i;i=nnext[i])
{
int tmp=num[i];//cout<<tmp<<' '<<f[tmp]<<endl;
if(tmp==fa[x]) continue ;
if(f[tmp]<0) f[x]-=f[tmp];
}
f[x]++;
// cout<<x<<' '<<f[x]<<endl;
if(f[x]>k) return false;
}
else
{
int tx=0;
int sum=0;
for(int i=g[x];i;i=nnext[i])
{
int tmp=num[i];
if(tmp==fa[x]) continue ;
if(f[tmp]>0)
{
if(f[tmp]<f[tx]) tx=tmp;
}
else
sum-=f[tmp];
}
//cout<<x<<' '<<tx<<endl;
if(f[tx]+sum+1<=k) f[x]=f[tx]+sum+1;
else if(sum+1<k) f[x]=-sum-1;
else return false;
}
if(--tin[fa[x]]==0) team[++tail]=fa[x];
}
// cout<<f[4]<<endl;
if(f[1]>0) return true;
return false;
}
int main()
{
// freopen("a.in","r",stdin);
// freopen("wa.out","w",stdout);
cin >>n >>k;
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d %d",&x,&y);
add(x,y);
add(y,x);
}
for(int i=1;i<=k;i++)
{
int x;
scanf("%d",&x);
is[x]=true;
}
init();
// cout<<ok(3)<<endl;
int L=1,R=n;
while(L<R)
{
// cout<<L<<' '<<R<<endl;
int mid=(L+R)/2;
// cout<<mid<<endl;//cout<<ok(4)<<endl;
if(ok(mid)) R=mid;
else L=mid+1;
}
cout<<L<<endl;
return 0;
}