题目链接:
http://poj.org/problem?id=3162
题意:n个结点构成一棵树 ,MC将在n天,依次按结点编号设为起点,选取距离起点最远的结点作为终点,得到最远距离。问:找到一个区间,使得这个区间里最大最小值的差距不超过m,求区间的最大长度。
解题思路:
求每天的最远距离很明显是树形dp的问题,求的n个值后,我们可以用线段树来存储这些值。每次维护区间的l,r,用线段树求的该区间的最大最小值,然后判断该区间是否符合要求。可以把l,r初始值都设为1,然后往右移动即可。
#include<iostream>
#include<stdio.h>
#include<string.h>
#define inf 0x3f3f3f3f
using namespace std;
const int mod=1e9+7;
const int maxn=1e6+5;
typedef long long ll;
int cnt;
int head[maxn];
struct st{
int to,next,w;
}stm[maxn*2];
void add(int u,int v,int w){
stm[cnt].to=v;
stm[cnt].next=head[u];
stm[cnt].w=w;
head[u]=cnt++;
}
int dis[maxn];
int n,m;
int dp[maxn][3];
void dfs1(int now,int fa){
dp[now][0]=0;//子树方向最远
dp[now][1]=0;//子树方向次远
for(int i=head[now];~i;i=stm[i].next){
int w=stm[i].w;
int to=stm[i].to;
if(to==fa)continue;
dfs1(to,now);
if(dp[to][0]+w>dp[now][0]){
dp[now][1]=dp[now][0];
dp[now][0]=dp[to][0]+w;
}
else if(dp[to][0]+w>dp[now][1]){
dp[now][1]=dp[to][0]+w;
}
}
}
void dfs2(int now,int fa){
for(int i=head[now];~i;i=stm[i].next){
int w=stm[i].w;
int to=stm[i].to;
if(to==fa)continue;
if(dp[now][0]>dp[to][0]+w){
dp[to][2]=max(dp[now][0],dp[now][2])+w;//父节点方向最远
}
else {
dp[to][2]=max(dp[now][1],dp[now][2])+w;
}
dfs2(to,now);
}
}
int tre[maxn*4][2];
void pushup(int rt){
tre[rt][0]=max(tre[rt<<1][0],tre[rt<<1|1][0]);
tre[rt][1]=min(tre[rt<<1][1],tre[rt<<1|1][1]);
}
void build(int l,int r,int rt){
if(l==r){
tre[rt][0]=tre[rt][1]=dis[l];
return ;
}
int mid=(l+r)/2;
build(l,mid,rt*2);
build(mid+1,r,rt*2+1);
pushup(rt);
}
int query0(int lm,int rm,int l,int r,int rt){
if(lm<=l&&rm>=r){
return tre[rt][0];
}
int mid=(l+r)/2;
int ans=0;
if(mid>=lm){
ans=max(ans,query0(lm,rm,l,mid,rt*2));
}
if(mid<rm){
ans=max(ans,query0(lm,rm,mid+1,r,rt*2+1));
}
return ans;
}
int query1(int lm,int rm,int l,int r,int rt){
if(lm<=l&&rm>=r){
return tre[rt][1];
}
int mid=(l+r)/2;
int ans=inf;
if(mid>=lm){
ans=min(ans,query1(lm,rm,l,mid,rt*2));
}
if(mid<rm){
ans=min(ans,query1(lm,rm,mid+1,r,rt*2+1));
}
return ans;
}
int main(){
int u,w;
cnt=0;
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
for(int i=2;i<=n;i++){
scanf("%d%d",&u,&w);
add(i,u,w);
add(u,i,w);
}
dfs1(1,-1);
dfs2(1,-1);
for(int i=1;i<=n;i++){
dis[i]=max(dp[i][0],dp[i][2]);
}
build(1,n,1);
int ans=0;
int l,r;
l=r=1;
while((l+r)<n*2){
int maxs=query0(l,r,1,n,1);
int mins=query1(l,r,1,n,1);
/* cout<<l<<" "<<r<<endl;
cout<<endl;
cout<<maxs<<" "<<mins<<endl;*/
if(mins+m>=maxs){
ans=max(ans,r-l+1);
r++;
if(r>n)break;
}
else{
l++;
}
}
cout<<ans;
return 0;
}