题意:
给出一颗n个结点的树,点上有权值;
求点对(x,y)满足x!=y且x到y的路径上最大值与最小值的差<=D;
n<=100000,多组数据,所有数据n的总和<=500000;
题解:
来填一填当年挖下的坑;
这个数据范围真是恶意。。直接说五组数据不好吗!
考虑这题怎么做,在这场考试那天的前一天,我学习了树分治算法;
然后他就出了,然后我就写了,然后我就写不出来了;
当年的我实在naive;
我翻出了当时交上去的代码,改了好久好久。。
首先这道题的思路比较容易,最直观的就是树的点分治嘛;
分治之后统计答案,首先搜出从当前根出发到所有点的最大权值与最小权值;
然后为了统计数量,我们先按最大权值从小到大排个序;
那么如果就让这个最大权值为路上的最大权值,那另一个端点一定排序在它前面;
所以一边遍历一边将点的最小值插入树状数组里;
每次查询[0,点到根路径上最大值-D]这个区间的值就好了;
复杂度O(nlog^2n),别人还有更加优越的log做法;
嘴巴很好A对吧,然而我当时真是没救;
我居然每次求出重心之后,没有用重心去分治!
这两个代码的区别就差一个字母啊!然后复杂度不对调了我半天。。
代码:
#pragma comment(linker, "/STACK:102400000,102400000")
#include<ctype.h>
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define N 100100
using namespace std;
typedef long long ll;
struct pr
{
int ma,mi;
friend bool operator <(pr a,pr b)
{
return a.ma<b.ma;
}
}popo[N];
int to[N<<1],nex[N<<1],head[N],val[N];
int size[N],dis[N],len;
int D,tot,mi,G,bk,cnt;
ll ans;
int sum[N];
bool ban[N];
void init()
{
memset(head,0,sizeof(head));
memset(ban,0,sizeof(ban));
tot=0,ans=0;
}
inline char getc()
{
static char buf[1<<15],*S,*T;
if(S==T)
{
T=(S=buf)+fread(buf,1,1<<15,stdin);
if(S==T) return EOF;
}
return *S++;
}
inline int read()
{
static char ch;
static int D;
while(!isdigit(ch=getc()));
for(D=ch-'0';isdigit(ch=getc());)
D=D*10+(ch-'0');
return D;
}
inline int lowbit(int x)
{
return x&(-x);
}
inline int lb(int x)
{
int l=1,r=len,mid;
while(l<=r)
{
mid=l+r>>1;
if(dis[mid]<x)
l=mid+1;
else
r=mid-1;
}
return l;
}
void update(int x,int val)
{
while(x<=len)
{
sum[x]+=val;
x+=lowbit(x);
}
}
int query(int x)
{
if(x<=0) return 0;
int ret=0;
while(x)
{
ret+=sum[x];
x-=lowbit(x);
}
return ret;
}
void add(int x,int y)
{
to[++tot]=y;
nex[tot]=head[x];
head[x]=tot;
}
void get_G(int x,int pre)
{
size[x]=1;
int i,y,temp=0;
for(i=head[x];i;i=nex[i])
{
if(!ban[y=to[i]]&&y!=pre)
{
get_G(y,x);
size[x]+=size[y];
temp=max(temp,size[y]);
}
}
temp=max(temp,bk-size[x]);
if(temp<mi)
mi=temp,G=x;
}
void dfs(int x,int pre,int ma,int mi)
{
if(ma-mi>D) return ;
ma=max(ma,val[x]);
mi=min(mi,val[x]);
static pr p;
p.ma=ma,p.mi=mi;
popo[++cnt]=p;
int i,y;
for(i=head[x];i;i=nex[i])
{
if(!ban[y=to[i]]&&y!=pre)
{
dfs(y,x,ma,mi);
}
}
}
ll calc(int x)
{
static pr p;
p.ma=p.mi=val[x];
popo[cnt=1]=p;
int i,j,y,last;
ll ret=0;
for(j=head[x];j;j=nex[j])
{
if(!ban[y=to[j]])
{
last=cnt;
dfs(y,x,val[x],val[x]);
sort(popo+last+1,popo+cnt+1);
for(i=last+1;i<=cnt;i++)
{
if(popo[i].ma-popo[i].mi<=D)
ret+=i-last-1-query(lb(popo[i].ma-D)-1);
update(lb(popo[i].mi),1);
}
for(i=last+1;i<=cnt;i++)
update(lb(popo[i].mi),-1);
}
}
ret=-ret;
sort(popo+1,popo+cnt+1);
for(i=1;i<=cnt;i++)
{
if(popo[i].ma-popo[i].mi<=D)
ret+=i-1-query(lb(popo[i].ma-D)-1);
update(lb(popo[i].mi),1);
}
for(i=1;i<=cnt;i++)
update(lb(popo[i].mi),-1);
return ret;
}
void slove(int x)
{
ban[x]=1;
if(bk==1) {ban[x]=0;return ;}
int i,y;
for(i=head[x];i;i=nex[i])
{
if(!ban[y=to[i]])
{
bk=size[y];
mi=0x3f3f3f3f;
get_G(y,x);
slove(G);
}
}
ans+=calc(x);
ban[x]=0;
}
int main()
{
// freopen("tt.in","r",stdin);
int c,T,n,m,i,j,k,x,y;
T=read();
for(c=1;c<=T;c++)
{
init();
n=read(),D=read();
for(i=1;i<=n;i++)
{
val[i]=read();
dis[i]=val[i];
}
sort(dis+1,dis+n+1);
len=unique(dis+1,dis+n+1)-dis-1;
for(i=1;i<n;i++)
{
x=read(),y=read();
add(x,y),add(y,x);
}
bk=n;
mi=0x3f3f3f3f;
get_G(1,0);
slove(G);
printf("%lld\n",ans<<1);
}
return 0;
}