题意:给一棵树,树上有一些关键节点,选m个点,使得关键节点到这些点中距离的最小值的最大值最小
最大值最小,果断二分答案
我们只需要判定是否存在m个点能够在mid范围内到达所有关键点
暴力:从每个点bfs一遍看看mid范围能是否能覆盖到所有的点,o(n^2logn)
发现可以贪心,一个关键点要么被它的子树内的点管理,要么被它子树外的点管理,于是我们记录个pair/struct
first表示以x为根的子树中目前还没有人管理的关键点距离x的最远的距离,second表示以x为根的子树中选择了的点距离x的最近的距离.
①if(first+second<=mid)以x为根的树是可以自己处理的
②if(first==mid)就意味着必须要选择x这个点了
因为再向上一个点距离就超过mid了,这时候强制选择x这个点,并更新first,second即可
③if(这个点是关键点&&second>mid)就要更新first了
注意出来的时候要特判1(树根)
#include <stdio.h>
#include <cstdlib>
#include <algorithm>
#include <cstring>
#include <time.h>
#pragma warning(disable:4996)
template<typename T> T min(T x, T y)
{
return x < y ? x : y;
}
template<typename T> T max(T x, T y)
{
return x > y ? x : y;
}
const int MAXN = 300005;
const int B = 400;
const int INF = 2000000005;
struct node {
int to;
node *next;
};
void addnode(node *&head, int to)
{
node *p = new node;
p->to = to;
p->next = head;
head = p;
}
int N, M;
node *edge[MAXN];
int deep[MAXN], fa[MAXN][25], key[MAXN], cnt;
int rank[MAXN], st[MAXN * 2][25], len;
int f[MAXN], list[MAXN], num;
bool mark[MAXN];
bool cmp(const int u, const int v)
{
return deep[u] > deep[v];
}
int anc(int x, int k)
{
for (int i = 0; i < 25; i++)
if (k&(1 << i))
x = fa[x][i];
return x;
}
int LCA(int x, int y)
{
x = rank[x];
y = rank[y];
if (x > y)
std::swap(x, y);
int len = y - x + 1;
int t = 0;
while (1 << t <= len)
t++;
t--;
y = y - (1 << t) + 1;
return deep[st[x][t]] > deep[st[y][t]]? st[y][t] : st[x][t];
}
int dis(int x, int y)
{
return deep[x] + deep[y] - 2 * deep[LCA(x, y)];
}
int nearest(int v)
{
int x = f[v];
for (int i = 1; i <= num; i++)
x = min(x, dis(v, list[i]));
return x;
}
void dfs1(int v)
{
f[v] = mark[v]? 0: INF;
for (node *p = edge[v]; p; p = p->next)
if (p->to != fa[v][0])
{
dfs1(p->to);
f[v] = min(f[v], f[p->to] + 1);
}
}
void dfs2(int v)
{
f[v] = min(f[v], f[fa[v][0]] + 1);
for (node *p = edge[v]; p; p = p->next)
if (p->to != fa[v][0])
dfs2(p->to);
}
void insert(int v)
{
list[++num] = v;
if (num == B)
{
for (int i = 1; i <= num; i++)
mark[list[i]] = true;
dfs1(1);
dfs2(1);
num = 0;
}
}
bool judge(int d)
{
int i, n = 0;
num = 0;
memset(mark, 0, sizeof(mark));
memset(f, 63, sizeof(f));
for (i = 1; i <= cnt; i++)
{
if (nearest(key[i]) > d)
{
n++;
insert(anc(key[i], min(d, deep[key[i]])));
}
}
return n <= M;
}
void dfs(int v)
{
st[++len][0] = v;
rank[v] = len;
for (int i = 1; i < 25; i++)
fa[v][i] = fa[fa[v][i - 1]][i - 1];
for (node *p = edge[v]; p; p = p->next)
if (p->to != fa[v][0])
{
fa[p->to][0] = v;
deep[p->to] = deep[v] + 1;
dfs(p->to);
st[++len][0] = v;
}
}
void init()
{
int i, j, u, v;
scanf("%d %d", &N, &M);
for (i = 1; i <= N; i++)
{
scanf("%d", &u);
if (u)
key[++cnt] = i;
}
for (i = 1; i < N; i++)
{
scanf("%d %d", &u, &v);
addnode(edge[u], v);
addnode(edge[v], u);
}
dfs(1);
std::sort(key + 1, key + cnt + 1, cmp);
deep[0] = INF;
for (i = 1; i < 25; i++)
{
int r = min(1 << (i - 1), len);
for (j = 1; j <= len; j++)
{
if (r < len)
r++;
st[j][i] = cmp(st[j][i - 1], st[r][i - 1]) ?
st[r][i - 1] : st[j][i - 1];
}
}
}
int main()
{
int l = -1, r = MAXN;
init();
while (r - l > 1)
{
int mid = (l + r) / 2;
if (judge(mid))
r = mid;
else
l = mid;
}
printf("%d\n", r);
return 0;
}