难在题意(代码5分钟读题两小时笑死)
题意:n个点构成的一棵树,求出树上每个点能到达的最远距离,这样就有n个距离,然后从这n个距离中找出最长的区间,使得区间内的最大值 − - −最小值 <= m
题解:最远距离就是求两遍 d f s dfs dfs,一遍求 x x x节点往儿子方向走的最远距离,一遍求往父亲方向的最远距离。直接树形dp搞一下就好了。。
然后求区间维护一个单调递增和一个单调递减的单调队列。
设置两个指针
i
=
1
,
j
=
1
i=1,j=1
i=1,j=1,如果单调队列中最大值减去最小值满足条件就不需要操作,让
j
j
j指针自增就好了;否则就需要更新答案并且更新左区间
i
i
i
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<stdio.h>
#include<string.h>
#include<queue>
#include<cmath>
#include<map>
#include<set>
#include<vector>
using namespace std;
#define inf 0x3f3f3f3f
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define mem(a,b) memset(a,b,sizeof(a));
#define lowbit(x) x&-x;
#define debugint(name,x) printf("%s: %d\n",name,x);
#define debugstring(name,x) printf("%s: %s\n",name,x);
typedef long long ll;
typedef unsigned long long ull;
const double eps = 1e-6;
const int maxn = 1e6+5;
const int mod = 1e9+7;
inline int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
struct node{
int to,nxt,w;
}e[maxn*2];
int head[maxn],tot;
void add(int u,int v,int w){
e[tot].to = v;
e[tot].nxt = head[u];
e[tot].w = w;
head[u] = tot++;
}
int dp1[maxn];
int dp2[maxn];
int pre[maxn];
int dis[maxn];
void dfs1(int u,int fa){
dp1[u] = 0;
pre[u] = fa;
for(int i = head[u]; ~i; i = e[i].nxt){
int v = e[i].to;
if(v == fa) continue;
dfs1(v,u);
dp1[u] = max(dp1[u],dp1[v]+e[i].w);
}
}
void dfs2(int u,int fa){
dp2[u] = dp2[fa];
int sum;
for(int i = head[fa]; ~i; i = e[i].nxt){
int v = e[i].to;
int w = e[i].w;
if(v == pre[fa]) continue;
if(v == u) sum = w;
else{
dp2[u] = max(dp2[u],dp1[v]+w);
}
}
dp2[u] += sum;
for(int i = head[u]; ~i; i = e[i].nxt){
int v = e[i].to;
if(v == fa) continue;
dfs2(v,u);
}
}
int qmax[maxn],qmin[maxn];
int main() {
int n,m,x,y;
scanf("%d%d",&n,&m);
mem(head,-1);
tot = 0;
for(int i=2; i<=n; i++){
scanf("%d%d",&x,&y);
add(i,x,y);
add(x,i,y);
}
dfs1(1,0);
dfs2(1,0);
for(int i = 1; i <= n; i++){
dis[i] = max(dp1[i],dp2[i]);
}
int f1,r1;
int f2,r2;
int ans = 0;
f1 = f2 = r1 = r2 = 0;
int i,j;
for(i = 1,j = 1; j <= n; j++){
while(f1 < r1 && dis[qmax[r1-1]] <= dis[j]) r1--;
qmax[r1++] = j;
while(f2 < r2 && dis[qmin[r2-1]] >= dis[j]) r2--;
qmin[r2++] = j;
if(dis[qmax[f1]]-dis[qmin[f2]] > m){
ans = max(ans,j-i);
while(dis[qmax[f1]]-dis[qmin[f2]] > m){
i = min(qmax[f1],qmin[f2])+1;
while(f1 < r1 && qmax[f1] < i) f1++;
while(f2 < r2 && qmin[f2] < i) f2++;
}
}
}
ans = max(ans,j-i);
printf("%d\n",ans);
}