记得这题还有个兄弟——poi2009 gas,至今还wa一个点。。。。
不管了,至少看上去这题比他兄弟和谐一点
首先毋庸置疑,肯定是二分答案。
O(n^2)的检查很好想吧。
从下向上扫描,如果一个炸弹暂时没有在规定时间内被引爆,那么一定是在尽可能上的地方放一个炸弹。——直接写就是O(n^2)了。。。。
程序片段
bool Judge(){
memset(p,-1,sizeof(p));
for (i=n,cnt=0;i>0;i--){
x=q[i];
if (!bm[x]) continue;
for (j=x,k=0;j>0&&k<=mid;k++,j=fa[j])
if (p[j]>=k) break;
if (j>0&&k<=mid) continue;
cnt++;
if (cnt>m) return 0;
for (j=x,k=0;j>1&&k<mid;k++,j=fa[j]);
for (k=mid;j>0&&k>=0;j=fa[j],k--)
if (p[j]<k) p[j]=k; else break;
}
return 1;
}
怎么加速检查?
表示被题解彻底鄙视了。。。。。
fire[]表示此子树向下最近的点火距离。
bomb[]表示此子树向下向上最远的未被引爆的炸弹。
两个数组直接转移就好了。。。。。
写起来比暴力差不多
bool Judge(){
memset(fire,127/2,sizeof(fire));
for (i=1;i<=n;i++)
if (bm[i]==1) bomb[i]=0;
else bomb[i]=-fire[0];
for (i=n,cnt=0;i>0;i--){
x=q[i];
if (fire[x]>mid) fire[x]=fire[0];
if (bomb[x]+fire[x]<=mid) bomb[x]=-fire[0];
if (bomb[x]==mid)
bomb[x]=-fire[0], fire[x]=0, cnt++;
if (cnt>m) return 0;
fire[ fa[x] ]=min(fire[ fa[x] ],fire[x]+1);
bomb[ fa[x] ]=max(bomb[ fa[x] ],bomb[x]+1);
}
if (bomb[1]>=0) cnt++;
return cnt<=m;
}
最后附上完整代码,在膜拜一下orzorzorz
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int Maxn=300005;
int node[Maxn*2],next[Maxn*2],a[Maxn],p[Maxn],q[Maxn];
int n,m,mid,ans,tot,l,r,i,j,k,cnt,x,y,fa[Maxn],bm[Maxn];
void add(int x,int y){
node[++tot]=y; next[tot]=a[x]; a[x]=tot;
node[++tot]=x; next[tot]=a[y]; a[y]=tot;
}
int dep[Maxn];
void bfs(){
for (q[l=r=1]=1;l<=r;l++)
for (i=a[q[l]];i;i=next[i])
if (node[i]!=fa[q[l]]){
fa[ q[++r]=node[i] ]=q[l];
dep[ q[r] ]=dep[ q[l] ]+1;
}
}
// bomb 子树下最远的炸药
// fire 子树下最近的点燃处
int fire[Maxn], bomb[Maxn];
bool Judge(){
memset(fire,127/2,sizeof(fire));
for (i=1;i<=n;i++)
if (bm[i]==1) bomb[i]=0;
else bomb[i]=-fire[0];
for (i=n,cnt=0;i>0;i--){
x=q[i];
if (fire[x]>mid) fire[x]=fire[0];
if (bomb[x]+fire[x]<=mid) bomb[x]=-fire[0];
if (bomb[x]==mid)
bomb[x]=-fire[0], fire[x]=0, cnt++;
if (cnt>m) return 0;
fire[ fa[x] ]=min(fire[ fa[x] ],fire[x]+1);
bomb[ fa[x] ]=max(bomb[ fa[x] ],bomb[x]+1);
}
if (bomb[1]>=0) cnt++;
return cnt<=m;
}
int main(){
freopen("dyn.in","r",stdin);
freopen("dyn.out","w",stdout);
scanf("%d%d",&n,&m);
for (i=1;i<=n;i++) scanf("%d",&bm[i]);
for (i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);
}
bfs();
l=0; r=(n+1)/2;
ans=r;
while (l<=r){
mid=(l+r)>>1;
if (Judge()) ans=mid, r=mid-1;
else l=mid+1;
}
printf("%d\n",ans);
return 0;
}