You are given a weighted tree consisting of nn vertices. Recall that a tree is a connected graph without cycles. Vertices uiui and vivi are connected by an edge with weight wiwi.
You are given mm queries. The ii-th query is given as an integer qiqi. In this query you need to calculate the number of pairs of vertices (u,v)(u,v)(u<vu<v) such that the maximum weight of an edge on a simple path between uu and vv doesn't exceed qiqi.
Input
The first line of the input contains two integers nn and mm (1≤n,m≤2⋅1051≤n,m≤2⋅105) — the number of vertices in the tree and the number of queries.
Each of the next n−1n−1 lines describes an edge of the tree. Edge ii is denoted by three integers uiui, vivi and wiwi — the labels of vertices it connects (1≤ui,vi≤n1≤ui,vi≤n, ui≠viui≠vi) and the weight of the edge (1≤wi≤2⋅1051≤wi≤2⋅105). It is guaranteed that the given edges form a tree.
The last line of the input contains mm integers q1,q2,…,qmq1,q2,…,qm (1≤qi≤2⋅1051≤qi≤2⋅105), where qiqi is the maximum weight of an edge in the ii-th query.
Output
Print mm integers — the answers to the queries. The ii-th value should be equal to the number of pairs of vertices (u,v)(u,v) (u<vu<v) such that the maximum weight of an edge on a simple path between uu and vv doesn't exceed qiqi.
Queries are numbered from 11 to mm in the order of the input.
Examples
input
Copy
7 5 1 2 1 3 2 3 2 4 1 4 5 2 5 7 4 3 6 2 5 2 3 4 1
output
Copy
21 7 15 21 3
input
Copy
1 2 1 2
output
Copy
0 0
input
Copy
3 3 1 2 1 2 3 2 1 3 2
output
Copy
1 3 3
Note
The picture shows the tree from the first example:
题意: 给你一棵树 然后树的边有权值 每次询问树的路径中最大的权值不超过k的路径数量
思路:如果我们一开始就把树建好 然后每次都去dfs的话 那么肯定是会超时的 那么就考虑将询问离线!!!(常用操作) 对于每一个离线的询问的值W 把小于W的边都加入进来 这个时候是用并查集维护联通快的个数 然后答案就是每个联通快中初始点的个数n * (n - 1) / 2;初始值为0 每次加边改变答案
当两个联通快合并时 需要减去原先两个联通块的贡献 然后在加上新合并联通块的贡献 然后记录答案
代码如下:
#include<bits/stdc++.h>
#define ll long long
#define N 200010
using namespace std;
int par[N];
ll sum[N],ans[N],temp=0;
struct A{
int u,v,w;
}e[N];
struct B{
int l,r;
}g[N];
bool cmp(A x,A y){
return x.w<y.w;
}
bool cmp1(B x,B y){
return x.l<y.l;
}
void init(int n){
for(int i=0;i<=n;i++){
par[i]=i;
sum[i]=1;
}
}
int fi(int x){
if(x!=par[x]){
par[x]=fi(par[x]);
}
return par[x];
}
void un(int x,int y){
x=fi(x);
y=fi(y);
if(x!=y){
temp-=(sum[x]*(sum[x]-1)/2);
temp-=(sum[y]*(sum[y]-1)/2);
par[x]=y;
sum[y]+=sum[x];
temp+=(sum[y]*(sum[y]-1)/2);
}
}
int main(){
int n,m;
scanf("%d %d",&n,&m);
init(n);
for(int i=1;i<n;i++){
scanf("%d %d %d",&e[i].u,&e[i].v,&e[i].w);
}
for(int i=0;i<m;i++){
scanf("%d",&g[i].l);
g[i].r=i;
}
sort(e+1,e+n,cmp);
sort(g,g+m,cmp1);
int len=1;
for(int i=0;i<m;i++){
while(g[i].l>=e[len].w&&len<n){
un(e[len].u,e[len].v);
len++;
}
ans[g[i].r]=temp;
}
for(int i=0;i<m;i++){
printf("%I64d ",ans[i]);
}
return 0;
}