题目大意:
给你一棵树,N个点(编号1~N),N-1条边,边上有权值。首先要知道每个顶点的所能到达的最远距离,然后有M个询问,每个询问给定一个Q,求出最多的编号连续的顶点,要求它们的最远距离的最大差值小于等于Q。数据范围:N<=50000,M<=500。
算法实现:
对于第一个问题,求出每个顶点所能到达的最远距离,是一个非常经典的问题。由于N最多有5W,我们不能对每个点都进行一次dfs。正确的解法是两次dfs,利用树形DP解决。每个顶点的最远距离可以由与它相邻的顶点的最远距离转移得到,但是也许与它相邻的顶点的最远距离恰好经过该顶点,因此我们还需要记录最远距离来自哪里以及次远距离。对于某个顶点,第一次dfs计算的是以该顶点为根,往下走的最远距离和次远距离。第二次dfs需从上往下更新,对于某个顶点,如果它的父亲结点的最远路径经过它,那么它往上走的最远距离等于它的父亲的次远距离加上它与它的父亲之间的边权;否则,它往上走的最远距离等于它的父亲的最远距离加上它与它的父亲之间的边权。往上走的最远距离以及往下走的最远距离的最大值就是某个顶点的最远距离。
第二个问题可以用RMQ的ST算法解决。ST算法的时间复杂度是预处理O(NlogN),查询O(1)。对于这道题目,由于要求出最多的编号连续的顶点,因此对于每次询问都要扫一遍最远距离数组,也就是每次询问要处理N个区间,总的时间复杂度是O(NlogN)+O(NM)。虽然ST算法中每次查询都是O(1),但是还是需要算一条表达式,其中包括了log函数,这个函数运行较慢,因此我们可以先把50000以内的log先预处理出来。
参考代码:
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#define MAXN 50010
#define cl(x,y) (memset(x,y,sizeof(x)))
using namespace std;
struct Edge
{
int u,v,w,next;
}e[MAXN*2];
int fa[MAXN],head[MAXN],tot;
int far1[MAXN],far2[MAXN],num[MAXN];//far1表示最远距离,far2表示次远距离,num表示最远路径的来自哪个顶点
int maxv[MAXN][16],minv[MAXN][16];
int pow2[16]={1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768};
double ln[MAXN];
void init()
{
for(int i=1;i<MAXN;i++)
ln[i]=log((double)i);
}
void ST(int N)
{
for(int i=1;i<=N;i++)
maxv[i][0]=minv[i][0]=far1[i];
for(int j=1;j<16;j++)
for(int i=1;i+pow2[j]-1<=N;i++)//这里不能i<=N,否则可能RE
{
maxv[i][j]=max(maxv[i][j-1],maxv[i+pow2[j-1]][j-1]);
minv[i][j]=min(minv[i][j-1],minv[i+pow2[j-1]][j-1]);
}
}
int getmax(int l,int r)
{
int k=ln[r-l+1]/ln[2];
return max(maxv[l][k],maxv[r-pow2[k]+1][k]);
}
int getmin(int l,int r)
{
int k=ln[r-l+1]/ln[2];
return min(minv[l][k],minv[r-pow2[k]+1][k]);
}
void add(int u,int v,int w)
{
e[tot].u=u;
e[tot].v=v;
e[tot].w=w;
e[tot].next=head[u];
head[u]=tot++;
}
void dfs1(int u)
{
far1[u]=far2[u]=0;
num[u]=-1;
for(int i=head[u];i!=-1;i=e[i].next)
{
int v=e[i].v;
if(v!=fa[u])
{
fa[v]=u;
dfs1(v);
if(far1[v]+e[i].w>far1[u])
{
far2[u]=far1[u];
far1[u]=far1[v]+e[i].w;
num[u]=v;
}
else if(far1[v]+e[i].w>far2[u])
far2[u]=far1[v]+e[i].w;
}
}
}
void dfs2(int u,int w)
{
if(fa[u]!=-1)
{
int tmp;
if(num[fa[u]]==u)
tmp=far2[fa[u]]+w;
else
tmp=far1[fa[u]]+w;
if(tmp>far1[u])
{
far2[u]=far1[u];
far1[u]=tmp;
num[u]=fa[u];
}
else if(tmp>far2[u])
far2[u]=tmp;
}
for(int i=head[u];i!=-1;i=e[i].next)
{
int v=e[i].v;
if(v!=fa[u])
dfs2(v,e[i].w);
}
}
int main()
{
int N,M;
init();
while(~scanf("%d%d",&N,&M))
{
if(N==0&&M==0)
break;
tot=0;
cl(head,-1);
for(int i=1;i<N;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
fa[1]=-1;
dfs1(1);
dfs2(1,-1);
ST(N);
while(M--)
{
int Q,l=1,r=2,ans=1;
scanf("%d",&Q);
while(r<=N)
{
if(getmax(l,r)-getmin(l,r)<=Q)
ans++;
else
l++;
r++;
}
printf("%d\n",ans);
}
}
return 0;
}