题意:给定一颗生成树,编号为1-n,每个顶点可以作为起点走最长的一段距离(未知)对于每个询问q,找到最长的一段连续编号的顶点使得编号中的最长距离的最大值与最小值的差小于q。
思路:首先用树形dp二次扫描换根法求出每个顶点作为起点能走的最长距离,用RMQ算法预处理一下就能很快查询区间内最大值和最小值。(预处理出log数组求logn,自带的log函数太慢了会超时)。最后对于每个查询使用取尺算法,就可以O(n)算出答案。
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 5e4 + 50;
struct node
{
int v, w;
node(int a, int b)
{
v = a, w = b;
}
};
int n, m;
vector<node>E[maxn];
void addE(int u, int v, int w)
{
E[u].push_back(node(v, w));
}
int far[2][maxn], dp[3][maxn]; //far0最远,far1次远,dp0最长子树,1次长子树,2父亲
void Dp(int x, int fa)
{
int len = E[x].size();
for(int i = 0; i < len; i++)
{
int y = E[x][i].v, w = E[x][i].w;
if(y == fa) continue;
Dp(y, x);
if(dp[0][x] < dp[0][y] + w)
{
dp[1][x] = dp[0][x];
dp[0][x] = dp[0][y] + w;
far[1][x] = far[0][x];
far[0][x] = y;
}
else
if(dp[1][x] < dp[0][y] + w)
{
dp[1][x] = dp[0][y] + w;
far[1][x] = y;
}
}
}
void dfs(int x, int fa)
{
int len = E[x].size();
for(int i = 0; i < len; i++)
{
int y = E[x][i].v, w = E[x][i].w;
if(y == fa) continue;
if(y != far[0][x])
dp[2][y] = max(dp[0][x], dp[2][x]) + w;
else
dp[2][y] = max(dp[1][x], dp[2][x]) + w;
dfs(y, x);
}
}
int maxx[maxn],mm[maxn];
int Min[maxn][50], Max[maxn][50];
void build()
{
mm[0] = -1;
for(int i = 1; i <= n; i++) {
mm[i] = ((i&(i-1))==0)?mm[i-1]+1:mm[i-1];
Min[i][0] = Max[i][0] = maxx[i];
}
for(int j = 1; j <= 20; j++)
{
for(int i = 1; i + (1 << j) - 1 <= n; i++)
{
Min[i][j] = min(Min[i][j - 1], Min[i + (1 << (j - 1))][j - 1]);
Max[i][j] = max(Max[i][j - 1], Max[i + (1 << (j - 1))][j - 1]);
}
}
}
int cxMin(int l, int r)
{
int lo = mm[r - l + 1];
return min(Min[l][lo], Min[r - (1 << lo) + 1][lo]);
}
int cxMax(int l, int r)
{
int lo = mm[r - l + 1];
return max(Max[l][lo], Max[r - (1 << lo) + 1][lo]);
}
int main()
{
// freopen("in.txt", "r", stdin);
// ios::sync_with_stdio(false);
while(~scanf("%d%d", &n, &m))
{
if(n == 0 && m == 0) break;
int a, b, w;
for(int i = 1; i <= n; i++) E[i].clear();
for(int i = 1; i < n; i++)
{
scanf("%d%d%d", &a, &b, &w);
addE(a, b, w);
addE(b, a, w);
}
for(int i = 0; i < 3; i++)for(int j = 1; j <= n; j++) dp[i][j] = 0;
Dp(1, -1); dfs(1, -1);
for(int i = 1; i <= n; i++)
maxx[i] = max(dp[0][i], dp[2][i]);
// for(int i = 1; i <= n; i++)cout << maxx[i] << endl;
build();
while(m--)
{
int M; scanf("%d", &M);
int ans = 1;
int l = 1, r = 1;
int d = 0;
while(1)
{
while(r <= n && d <= M)
{
r++;
d = cxMax(l, r) - cxMin(l, r);
}
ans = max(ans, r - l);
if(r >= n || l >= n) break;
l++;
d = cxMax(l, r) - cxMin(l, r);
}
printf("%d\n", ans);
}
}
return 0;
}