题意:
给定N个点的树,点有权值,求多少个点对(u,v)满足u到v的路径上点权值最大值减最小值不大于给定的K
思路:
将点对分成经过根的和不经过根的,进行分治
每一次分治都维护点到根的最大值和最小值就可以了,
处理完一个子树的时候,利用容斥+二分查找左右范围就可以了
(路径上两点的较小值是两点到根较小的那个,较大值为两点到根较大的)
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <ctime>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
#define INF 0x3f3f3f3f
#define inf -0x3f3f3f3f
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define mem0(a) memset(a,0,sizeof(a))
#define mem1(a) memset(a,-1,sizeof(a))
#define mem(a, b) memset(a, b, sizeof(a))
#define MP(x,y) make_pair(x,y)
typedef long long ll;
void fre() { freopen("input.in", "r", stdin); freopen("output.out", "w", stdout); }
template <class T1, class T2>inline void gmax(T1 &a, T2 b) { if (b>a)a = b; }
template <class T1, class T2>inline void gmin(T1 &a, T2 b) { if (b<a)a = b; }
typedef pair<int,int>PI;
const int maxn=100100;
const int MAXM=2*maxn;
vector<int>G[maxn];
int dis[maxn],k,size[maxn],f[maxn],Count,root,a[maxn];//Count表示当前子树的结点的总个数
bool Del[maxn];
long long ans=0;
struct node{
int _min,_max;
bool operator < (const node& a) const{
return a._min==_min ? _max<a._max : _min<a._min;
}
}P[maxn];
int C[maxn],cnt;
struct Edge{
int to,next;
}e[MAXM];
int tot,head[maxn];
void init(){
tot=0;
memset(head,-1,sizeof(head));
}
void addedge(int u,int v){
e[tot].to=v;
e[tot].next=head[u];
head[u]=tot++;
}
void getroot(int u,int pre){
size[u]=1,f[u]=0;
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].to;
if(v!=pre && !Del[v]){
getroot(v,u);
size[u]+=size[v];
f[u]=max(f[u],size[v]);
}
}
f[u]=max(f[u],Count-size[u]);
if(f[u]<f[root]) root=u;
}
void getdep(int u,int pre,int _min,int _max){ //这里还需要重新计算每个子树的size
_min=min(a[u],_min),_max=max(a[u],_max);
if(_min+k>=_max)
P[++cnt]=(node){_min,_max};
size[u]=1;
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].to;
if(v!=pre && !Del[v]){
getdep(v,u,_min,_max);
size[u]+=size[v];
}
}
}
void add(int x,int add,int m){
while(x<=m)
C[x]+=add,x+=(x&-x);
}
int sum(int x){
int ret=0;
while(x>0)
ret+=C[x],x-=(x&-x);
return ret;
}
long long cal(int u,int _min,int _max){
cnt=0;
getdep(u,0,_min,_max);
sort(P+1,P+cnt+1);
long long ret=0;
for(int i=cnt;i>=1;i--){ //最小值为P[i]._min,最大值大于等于P[i]._min,小于等于P[i]._min+k
int num=lower_bound(P+1,P+i,(node){P[i]._max-k,0})-P; //后面的最大值一定比P[i]._minv大
ret+=(i-num);
}
return ret;
}
void work(int u){
ans+=cal(u,INF,-INF);
Del[u]=true;
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].to;
if(!Del[v]){
ans-=cal(v,a[u],a[u]);
f[0]=Count=size[v];
getroot(v,root=0);
work(root);
}
}
}
int main(){
int n,_;
scanf("%d",&_);
while(_--){
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
init();
int u,v,w;
for(int i=1;i<n;i++){
scanf("%d%d",&u,&v);
addedge(u,v),addedge(v,u);
}
ans=0;
f[0]=Count=n;
getroot(1,root=0);
memset(Del,false,sizeof(Del));
work(root);
printf("%lld\n",ans*2);
}
return 0;
}