H 国有 n 个城市,这 n 个城市用 n-1 条双向道路相互连通构成一棵树,1 号城市是首都,
也是树中的根节点。
H 国的首都爆发了一种危害性极高的传染病。当局为了控制疫情,不让疫情扩散到边境
城市(叶子节点所表示的城市),决定动用军队在一些城市建立检查点,使得从首都到边境
城市的每一条路径上都至少有一个检查点,边境城市也可以建立检查点。但特别要注意的是,
首都是不能建立检查点的。
现在,在 H国的一些城市中已经驻扎有军队,且一个城市可以驻扎多个军队。一支军
队可以在有道路连接的城市间移动,并在除首都以外的任意一个城市建立检查点,且只能在
一个城市建立检查点。一支军队经过一条道路从一个城市移动到另一个城市所需要的时间等
于道路的长度(单位:小时)。
请问最少需要多少个小时才能控制疫情。注意:不同的军队可以同时移动。
第一行一个整数 n,表示城市个数。
接下来的 n-1 行,每行 3 个整数,u、v、w,每两个整数之间用一个空格隔开,表示从
城市 u 到城市 v 有一条长为 w 的道路。数据保证输入的是一棵树,且根节点编号为 1。
接下来一行一个整数 m,表示军队个数。
接下来一行 m 个整数,每两个整数之间用一个空格隔开,分别表示这 m 个军队所驻扎
的城市的编号。
共一行,包含一个整数,表示控制疫情所需要的最少时间。如果无法控制疫情则输出-1。
4
1 2 1
1 3 2
3 4 3
2
2 2
3
【输入输出样例说明】
第一支军队在 2 号点设立检查点,第二支军队从 2 号点移动到 3 号点设立检查点,所需
时间为 3 个小时。
【数据范围】
保证军队不会驻扎在首都。
对于 20%的数据,2≤ n≤ 10;
对于 40%的数据,2 ≤n≤50,0<w <10^5;
对于 60%的数据,2 ≤ n≤1000,0<w <10^6;
对于 80%的数据,2 ≤ n≤10,000;
对于 100%的数据,2≤m≤n≤50,000,0<w <10^9。
题解:二分+贪心+倍增
二分时间t ,表示在t的时间内军队能否控所有点
显然一支军队在到根节点之前 如果能继续向上走 那么这支军队能控制的点就会更多,贡献就越大。
维护mark[i]表示 从该点到所有该点子树的叶子节点路径上是否都有军队
对于每个不能走到根节点的军队 就让他尽量向上走 直到不能走为止 将该点的mark值赋为1,然后dfs一边处理出所有点的mark值。
对于根节点的所有儿子中mark为1的点我们可以不用再考虑,而剩下的儿子他们的子树中存在没有军队覆盖的点,也就是需要别的子树的军队来补充,而覆盖该点一定是代价最小的方案。
为什么呢?首先所有只能在当前子树中移动的军队已经达到了最大贡献,那么只能通过别的子树中的军队来救,而根的儿子节点是需要移动距离最近的节点,而且可以一个军队就搞定。
于是问题转换为 有n个点需要被控制 ,且控制该点需要该点到根的时间 还有m支军队 每支军队有总时限-到根的时间的剩余时间 求这m支军队是否能控制这n个点
将两个数组分别从大到小排序,用两个指针一个个扫就行了
但是要注意的是如果军队j就是从i点走到首都的 那么该军队不论有多少剩余时间都能控制该点
为了解决这个问题,我们记录一个minarm数组,如果军队i经过该点到达首都 那么用minarm[该点]记下这些军队中剩余时间最少的是那只军队 当指针扫到i时先判断minarm[i]是否被用过了 如果没有,那么用minarm[i]来控制i,否则再在j指针上找军队
因为是从大到小排序,所以先匹配的都是时间长度,如果minarm[i]没被使用 那么t[j]>=t[minarm[i]] 所以在i点如果用minarm[i]来覆盖,就能省下一个剩余时间更多的军队 对于之后可能用到minarm[i]的点完全可以用j点代替
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 100003
#define inf 1000000000
#define LL long long
using namespace std;
int n,m,tot;
int point[N],next[N],v[N],sol[N],deep[N],f[20][N];
LL mi[20],g[20][N],c[N],len[N];
int in[N],out[N],mark[N],use[N];
int belong[N];
struct data
{
int x,left;
}p[N],q[N],minarm[N];
int cmp(data a,data b)
{
return a.left>b.left;
}
void add(int x,int y,int z)
{
tot++; next[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z;
tot++; next[tot]=point[y]; point[y]=tot; v[tot]=x; c[tot]=z;
}
void dfs(int x,int fa)
{
deep[x]=deep[fa]+1;
for (int i=1;i<=13;i++)
{
if (deep[x]-mi[i]<0) break;
f[i][x]=f[i-1][f[i-1][x]];
g[i][x]=g[i-1][x]+g[i-1][f[i-1][x]];
}
for (int i=point[x];i;i=next[i])
if(v[i]!=fa)
{
out[x]++;
in[v[i]]++;
f[0][v[i]]=x;
g[0][v[i]]=c[i];
len[v[i]]=len[x]+c[i];
if (fa) belong[v[i]]=belong[x];
else belong[v[i]]=v[i];
dfs(v[i],x);
}
}
int dfs1(int x,int fa)
{
if (mark[x]) return mark[x];
int t=1;
for (int i=point[x];i;i=next[i])
if (v[i]!=fa)
t=min(t,dfs1(v[i],x)),mark[x]=t;
return mark[x];
}
bool check(LL x)
{
memset(mark,0,sizeof(mark));
for (int i=1;i<=n;i++)
minarm[i].left=inf,minarm[i].x=-1;
int cnt=0,cnt1=0;
for (int i=1;i<=m;i++)
if (len[sol[i]]<x)
{
q[++cnt].left=x-len[sol[i]],q[cnt].x=i;
if (minarm[belong[sol[i]]].left>x-len[sol[i]])
minarm[belong[sol[i]]].left=x-len[sol[i]],
minarm[belong[sol[i]]].x=i;
}
else
{
if (f[0][sol[i]]==1)
{
mark[sol[i]]=1;
continue;
}
mark[sol[i]]=1;
LL t=0; int now=sol[i];
for (int j=13;j>=0;j--)
while (t+g[j][now]<=x&&g[j][now]!=0&&f[j][now]!=1)
t+=g[j][now],now=f[j][now],mark[now]=1;
}
dfs1(1,0);
for (int i=point[1];i;i=next[i])
if (!mark[v[i]])
p[++cnt1].left=c[i],p[cnt1].x=v[i];
//cout<<cnt1<<" "<<cnt<<" "<<x<<endl;
if (cnt1>cnt) return false;
sort(p+1,p+cnt1+1,cmp);
sort(q+1,q+cnt+1,cmp);
memset(use,0,sizeof(use));
int j=1; int i=1;
while (j<=cnt&&i<=cnt1)
{
if (!use[minarm[p[i].x].x]&&minarm[p[i].x].x!=-1)
use[minarm[p[i].x].x]=1,i++;
else
{
while (use[q[j].x]) j++;
if (q[j].left>=p[i].left)
i++,use[q[j].x]=1,j++;
else return false;
}
}
if (i!=cnt1+1) return false;
return true;
}
int main()
{
scanf("%d",&n);
LL sum=0; int k=0;
for (int i=1;i<=n-1;i++)
{
int x,y; LL z; scanf("%d%d%lld",&x,&y,&z);
add(x,y,z); sum+=z;
if (x==1||y==1) k++;
}
mi[0]=1; for (int i=1;i<=13;i++) mi[i]=mi[i-1]*2;
dfs(1,0);
scanf("%d",&m);
for (int i=1;i<=m;i++)
scanf("%d",&sol[i]),mark[sol[i]]=1;
if (m<k) {
printf("-1\n");
return 0;
}
if (dfs1(1,0))
{
printf("0\n");
return 0;
}
LL l=0; LL r=sum;
LL ans=sum;
while (l<=r)
{
LL mid=(l+r)/2;
if (check(mid)) ans=min(ans,mid),r=mid-1;
else l=mid+1;
}
printf("%lld\n",ans);
}