思路
二分答案+贪心是肯定的,难的是check()函数,参考了神犇的思路,给定时间k,对于每一支军队往在k时间内能到达的最高点走,这样会得到最优解,先预处理出每支军队到达根节点的时间,每次二分时判断多少军队能到达根节点(这些军队可以调动到没有被控制的节点上),然后处理出没有达到根节点的军队最高能到达的节点,将能到达的最高节点打标记;用DFS将标记上传,也就是说如果该节点的所有儿子都被打了标记,那该节点也会被控制;建立两个结构体,一个记录可以调动的军队,另一个记录尚未被控制的节点,按照降序排序,注意一点,如果军队所在的节点没有被控制的话就让他自己去控制好了,其他的只要剩余时间大于到达时间贪心就可以了;
Code
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int MAXN=50000+10;
int head[MAXN],p[MAXN][20];
int d[MAXN],g[MAXN][20];
int arm[MAXN],r[MAXN];
bool vis[MAXN];
int n,u,v,w,m,num,sum;
struct Edge{
int next,to,w;
}edge[MAXN<<1];
struct Point{
int w,from;
}b[MAXN<<1],c[MAXN<<1];
void add(int from,int to,int w)
{
edge[++num].to=to;
edge[num].w=w;
edge[num].next=head[from];
head[from]=num;
}
void dfs(int u)
{
for(int i=head[u];i;i=edge[i].next)
if(!d[edge[i].to])
{
int to=edge[i].to;
d[to]=d[u]+1;
g[to][0]=edge[i].w;
p[to][0]=u;
dfs(to);
}
}
void init()
{
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i<=n;i++)
if(p[i][j-1])
{
p[i][j]=p[p[i][j-1]][j-1];
g[i][j]=g[i][j-1]+g[p[i][j-1]][j-1];
}
}
void work()
{
for(int i=1;i<=m;i++)
{
int f=d[arm[i]]-d[1],x=arm[i];
for(int j=(int)log2(n);j>=0;j--)
if((1<<j)&f) r[i]+=g[x][j],x=p[x][j];
}
}
void pro(int i,int res)
{
int time=0;
for(int j=(int)log2(n);j>=0;j--)
if(p[i][j]&&g[i][j]+time<=res)
{
time+=g[i][j];
i=p[i][j];
}
vis[i]=1;
}
void pushup(int u)
{
int p1=1,q=0;
for(int i=head[u];i;i=edge[i].next)
if(edge[i].to!=p[u][0])
{
pushup(edge[i].to);
p1=p1&vis[edge[i].to];
q=1;
}
if(p1&&q&&u!=1) vis[u]=1;
}
int cmp(Point a,Point b) {return a.w<b.w;}
int check(int time)
{
memset(vis,0,sizeof vis);
int cnt=0,top=0;
for(int i=1;i<=m;i++)
if(r[i]>time) pro(arm[i],time);
else {
int y=arm[i];
b[++cnt].w=time-r[i];
for(int j=(int)log2(n);j>=0;j--)
if(p[y][j]>1) y=p[y][j];
b[cnt].from=y;
}
pushup(1);
for(int i=head[1];i;i=edge[i].next)
if(!vis[edge[i].to])
{
c[++top].from=edge[i].to;
c[top].w=edge[i].w;
}
sort(b+1,b+cnt+1,cmp);
sort(c+1,c+top+1,cmp);
int j=1;c[top+1].w=0x7fffffff;
for(int i=1;i<=cnt;i++)
{
if(!vis[b[i].from]) vis[b[i].from]=1;
else if(b[i].w>=c[j].w) vis[c[j].from]=1;
while(vis[c[j].from]) j++;
}
if(j>top) return 1;
return 0;
}
int main()
{
freopen("01.in","r",stdin);
scanf("%d",&n);
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&u,&v,&w);
sum+=w;
add(u,v,w);
add(v,u,w);
}
d[1]=1;
dfs(1);
init();
scanf("%d",&m);
for(int i=1;i<=m;i++)
scanf("%d",&arm[i]);
work();
int l=0,r=sum,ans=0;
while(l<=r)
{
int m=(l>>1)+(r>>1)+(l&r&1);
if(check(m)) ans=m,r=m-1;
else l=m+1;
}
printf("%d",ans);
return 0;
}