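/*
 * Problem: count the pairs of tree vertices whose distance is at most k.
 * Approach (centroid decomposition): at each step find the centroid of the
 * current component and count the valid pairs whose path passes through it,
 * then remove the centroid and handle each remaining subtree recursively the
 * same way.
 * Because the recursion has at most O(log n) levels, every vertex is visited
 * at most O(log n) times.
 * Note to self: the answer was printed as long long, but ans had been declared
 * as int... that caused a long run of Wrong Answers. Careless mistake.
 */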
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#include<algorithm>
using namespace std;
// Edge record for the singly linked adjacency lists built by insert():
//   num  - target vertex of the edge
//   next - index in e[] of the next edge leaving the same source (0 ends the list)
//   w    - edge weight
// Every undirected input edge is stored twice (once per direction) and index 0
// is never used, so e[] needs roughly twice as many slots as the per-vertex
// arrays (sized 50010) to hold a tree with that many vertices.
struct rec
{
    int num,next,w;
}e[100020];
// Per-vertex scratch data for the component currently being processed.
// Entries are reset to all-zero (tree_clear / clear) after each use.
struct rec2
{
    int dist;  // distance from the current root (filled by tree_dist_work)
    int sum;   // subtree size in the BFS tree (filled by tree_core_work)
    int maxf;  // size of the largest child subtree (filled by tree_core_work)
    int fa;    // BFS parent; 0 for the root because its entry is cleared
    int bl;    // child of the root whose subtree contains this vertex (root: itself)
}tree[50010];
int link[50010];   // link[v]: index in e[] of the first edge leaving v (0 = none)
int q[50010];      // BFS queue shared by the traversal routines, used 1-based
int e_tot;         // number of edge slots of e[] used so far
long long tot_ans; // running answer for the current test case
// Zeroes every scratch field of vertex x so it can be reused by the next pass.
void tree_clear(int x)
{
    tree[x].bl = tree[x].dist = tree[x].fa = tree[x].maxf = tree[x].sum = 0;
}
// Prepends a directed edge x -> y of weight w to x's adjacency list.
void insert(int x,int y,int w)
{
    const int id = ++e_tot;
    e[id].num = y;
    e[id].w = w;
    e[id].next = link[x];
    link[x] = id;
}
bool flag[50010]; // flag[v]: v was already used as a centroid and is removed
int n;            // number of vertices in the current test case
int k;            // distance bound: count pairs with distance <= k
// Finds a centroid of the component of not-yet-removed vertices (flag[] false)
// that contains x: the vertex whose largest remaining piece after removal is
// smallest. Relies on tree[x].fa not naming a neighbour of x (callers pass a
// cleared entry, so it is 0). Every vertex it visits is cleared again before
// returning, and the BFS overwrites q[1..component size].
int tree_core_work(int x)
{
    // BFS from x: record parents and initialise the subtree counters.
    q[1]=x;
    int head=0,tail=1;
    while (head!=tail)
    {
        head++;
        int now=q[head];
        tree[now].sum=1;
        tree[now].maxf=0;
        for (int p=link[now];p;p=e[p].next)
            if (e[p].num!=tree[now].fa&&flag[e[p].num]==false)
            {
                q[++tail]=e[p].num;
                tree[e[p].num].fa=now;
            }
    }

    // tail is the component size. Walking the BFS order backwards visits every
    // vertex after all of its children, so its sum/maxf are final when read.
    int best=tail+1,centre=x;
    for (int i=tail;i;i--)
    {
        int v=q[i];
        // Largest piece left if v is removed: its biggest child subtree, or
        // everything outside v's own subtree.
        int worst=max(tree[v].maxf,tail-tree[v].sum);
        if (worst<best)
        {
            best=worst; // must use v's own maxf, not the root's partial value
            centre=v;
        }
        if (i>1) // q[1] is the root; it has no parent to push sizes into
        {
            int parent=tree[v].fa;
            tree[parent].sum+=tree[v].sum;
            tree[parent].maxf=max(tree[parent].maxf,tree[v].sum);
        }
        tree_clear(v);
    }
    return centre;
}
// Breadth-first walk over the not-yet-removed component containing x.
// For every reached vertex it fills tree[].dist (distance from x), tree[].fa
// (BFS parent) and tree[].bl (the child of x whose subtree holds it; x itself
// for x). Relies on tree[x].dist and tree[x].fa being cleared beforehand.
// Leaves the visited vertices in q[1..return value] and returns their count.
int tree_dist_work(int x)
{
    int head = 0, tail = 0;
    q[++tail] = x;
    tree[x].bl = x;
    while (head < tail)
    {
        const int cur = q[++head];
        for (int p = link[cur]; p != 0; p = e[p].next)
        {
            const int to = e[p].num;
            if (to == tree[cur].fa || flag[to])
                continue; // skip the way back and already-removed centroids
            q[++tail] = to;
            tree[to].dist = tree[cur].dist + e[p].w;
            tree[to].fa = cur;
            // Direct children of x start a new branch; deeper vertices inherit it.
            tree[to].bl = (cur == x) ? to : tree[cur].bl;
        }
    }
    return tail;
}
// sort() comparator on vertex ids: ascending distance from the current centroid.
int cmp1(int x,int y)
{
    return tree[y].dist > tree[x].dist;
}
// sort() comparator on vertex ids: group by branch (tree[].bl), then order by
// ascending distance inside each branch.
int cmp2(int x,int y)
{
    const bool same_branch = (tree[x].bl == tree[y].bl);
    return same_branch ? tree[x].dist < tree[y].dist
                       : tree[x].bl < tree[y].bl;
}
// Counts pairs i < j in q[lo..hi] whose tree[].dist values sum to at most k.
// q[lo..hi] must be sorted by ascending dist. Two-pointer scan: as i grows, the
// largest admissible partner index can only move left.
static long long count_close_pairs(int lo,int hi)
{
    long long pairs=0;
    while (lo<hi)
    {
        while (lo<hi&&tree[q[lo]].dist+tree[q[hi]].dist>k) hi--;
        pairs+=hi-lo;
        lo++;
    }
    return pairs;
}

// Adds to tot_ans the number of vertex pairs at distance <= k whose path goes
// through the centroid x, then removes x (flag[x] = true) and recurses into every
// remaining neighbouring component through that component's centroid.
// Pairs through x = close pairs in x's whole component (distances measured from x)
//                 - close pairs whose endpoints lie in the same branch of x,
// because the latter do not actually pass through x.
void tree_ans_work(int x)
{
    int len=tree_dist_work(x);

    sort(q+1,q+len+1,cmp1);
    // long long: the pair count grows quadratically with the component size.
    long long tmp_sum=count_close_pairs(1,len);

    // Group by branch (bl), ascending dist inside each group. x has bl == x,
    // so it forms a group of its own and subtracts nothing.
    sort(q+1,q+len+1,cmp2);
    for (int first=1;first<=len;)
    {
        int last=first;
        // Test the bound before touching q[last+1] so we never read past len.
        while (last<len&&tree[q[last+1]].bl==tree[q[first]].bl) last++;
        tmp_sum-=count_close_pairs(first,last);
        first=last+1;
    }
    tot_ans+=tmp_sum;

    flag[x]=true;
    for (int i=1;i<=len;i++)
        tree_clear(q[i]);
    // q[] is overwritten by the calls below; only the adjacency list of x is
    // needed from here on.
    for (int p=link[x];p;p=e[p].next)
        if (!flag[e[p].num])
            tree_ans_work(tree_core_work(e[p].num));
}
// Resets every global structure so the next test case starts from scratch.
void clear()
{
    e_tot=0;
    tot_ans=0;
    memset(flag,0,sizeof(flag));
    memset(link,0,sizeof(link));
    memset(e,0,sizeof(e));
    memset(q,0,sizeof(q));
    memset(tree,0,sizeof(tree));
}
// Reads test cases ("n k" followed by n-1 lines "x y w" describing weighted tree
// edges) until a "0 0" line or end of input. For each case it prints the number
// of vertex pairs at distance <= k; answers are separated by newlines.
int main()
{
#ifdef LOCAL
    // Local debugging only: redirecting stdio on a judge would hide all output.
    freopen("test.in","r",stdin);
    freopen("test.out","w",stdout);
#endif
    int ci=0;
    // scanf returns EOF (-1, which is truthy) at end of input, so require both fields.
    while (scanf("%d%d",&n,&k)==2)
    {
        if (n==0&&k==0) break;
        if (ci++) printf("\n");
        clear();
        for (int i=1;i<n;i++)
        {
            int x,y,w;
            scanf("%d%d%d",&x,&y,&w);
            insert(x,y,w);
            insert(y,x,w);
        }
        tree_ans_work(tree_core_work(1));
        printf("%lld",tot_ans);
    }
    return 0;
}