首先我们先求出每个点到叶节点的最大距离。从编号为1的节点开始,存在一个数组中。
然后在用rmq预处理这个数组的最大值和最小值
最后二分查找答案。
现在我们讨论如何解决这3步
第一步:树形dp,需要两次遍历。第一次遍历,任取一个节点为根节点u,树的层次就确定了。可以递归求出每个点到其子孙的最远距离。每个节点记录两个最值,最大值和次大值,并记录这两个值来自哪个节点。注意:这两个值来自不同的子树中。然后进行第二次遍历,从u开始遍历子节点。子节点的最值要么来自其子孙中,要么来自其父节点的最值中。这里我们可以用父节点的最值来更新子节点的最值。若父节点最值来自这个子节点,则用次大值更新,否则用最大值更新。最后我们把每个节点的最大值按编号从1-n存在一个数组中。
第二步:rmq预处理出数组的最大值和最小值,这一步不难。
第三部:题目要求差值不大于Q,所以我们枚举人数,看是否符合,取最大值。
复杂度分析:
第一步求最大距离,2次遍历都是O(n),第二部rmq预处理为O(nlgn),第三部查询最多有500,二分答案lgn,然后枚举起点需要O(n-m),m为长度,最坏情况为:500*lgn*(n-m)。
优化:
二分查找时,每次每个长度求出后即记录下来,下一次再访问时可以避免枚举,直接得到答案。不加会TLE。
Rmq查询时,计算区间长度的log值,以前都调用函数库的log函数,数学函数处理double值,速度一般比较慢,在多查询中体现的明显,后来预处理出来,然后就过了,不然会TLE。
程序最后优化的比较快,265MS,在杭电排Rank1
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int N = 51000;
const int M = 17;
const int eps = 1e-9;
int d[N][2], Fm[N][2], head[N];
int mi[M][N], mx[M][N];
int Log[N];
int ans[N];
struct EDG{
int u, v, c, next;
}g[N * 2];
int cnt, n;
void add(int u, int v, int c)
{
g[cnt].u = u; g[cnt].v = v; g[cnt].c = c; g[cnt].next = head[u]; head[u] = cnt++;
g[cnt].u = v; g[cnt].v = u; g[cnt].c = c; g[cnt].next = head[v]; head[v] = cnt++;
}
int dfs(int cur, int p)
{
for(int e = head[cur]; e != -1; e = g[e].next){
int v = g[e].v;
if(v != p){
int t = dfs(v, cur) + g[e].c;
if(d[cur][1] < t){
d[cur][0] = d[cur][1]; Fm[cur][0] = Fm[cur][1];
d[cur][1] = t; Fm[cur][1] = v;
}else if(d[cur][0] < t){
d[cur][0] = t; Fm[cur][0] = v;
}
}
}
return d[cur][1];
}
void Dp(int cur, int p, int c)
{
if(cur != Fm[p][1]){
if(d[cur][1] < d[p][1] + c){
d[cur][0] = d[cur][1]; Fm[cur][0] = Fm[cur][1];
d[cur][1] = d[p][1] + c; Fm[cur][1] = p;
}else if(d[cur][0] < d[p][1] + c){
d[cur][0] = d[p][1] + c; Fm[cur][0] = p;
}
}else {
if(d[cur][1] < d[p][0] + c){
d[cur][0] = d[cur][1]; Fm[cur][0] = Fm[cur][1];
d[cur][1] = d[p][0] + c; Fm[cur][1] = p;
}else if(d[cur][0] < d[p][0] + c){
d[cur][0] = d[p][0] + c; Fm[cur][0] = p;
}
}
for(int e = head[cur]; e != -1; e = g[e].next){
int v = g[e].v;
if(v != p){
Dp(v, cur, g[e].c);
}
}
}
void ST()
{
for(int i = 1; i <= n; i++)
mi[0][i] = mx[0][i] = d[i][1];
for(int i = 1; (1 << i) < n; i++){
for(int j = n; j >= 1; j--){
mx[i][j] = mx[i - 1][j];
if(j + (1 << (i - 1)) <= n)
mx[i][j] = max(mx[i][j], mx[i - 1][j + (1 << (i - 1))]);
mi[i][j] = mi[i - 1][j];
if(j + (1 << (i - 1)) <= n)
mi[i][j] = min(mi[i][j], mi[i - 1][j + (1 << (i - 1))]);
}
}
}
inline int rmq(int l, int r)
{
int m = Log[r - l + 1];
int a = max(mx[m][l], mx[m][r - (1 << m) + 1]);
int b = min(mi[m][l], mi[m][r - (1 << m) + 1]);
return a - b;
}
inline int check(int m)
{
int &res = ans[m];
if(res != -1) return res;
for(int i = 1; i + m - 1 <= n; i++){
if(res == -1) res = rmq(i, i + m - 1);
else res = min(res, rmq(i, i + m - 1));
}
return res;
}
int main()
{
//freopen("input.txt", "r", stdin);
int m;
int u, v, c;
for(int i = 0, j = 0; i <= 50000; i++){
if((1 << j) >= ((i + 1) >> 1)){
Log[i] = j;
}else{
Log[i] = ++j;
}
}
while(scanf("%d %d", &n, &m) == 2 && m + n){
cnt = 0;
memset(head, -1, sizeof(head));
memset(d, 0, sizeof(d));
memset(Fm, 0, sizeof(Fm));
for(int i = 1; i < n; i++){
scanf("%d %d %d", &u, &v, &c);
add(u, v, c);
}
dfs(1, 0);
Dp(1, 0, 0);
ST();
memset(ans, -1, sizeof(ans));
ans[1] = 0;
for(int i = 0; i < m; i++){
scanf("%d", &u);
int l = 1, r = n + 1;
while(l < r){
int mid = (l + r) >> 1;
if(check(mid) <= u){
l = mid + 1;
}else r = mid;
}
printf("%d\n", l - 1);
}
}
return 0;
}