题目链接:http://codeforces.com/problemset/problem/1213/G
题意:
首先给出两个数n,m(1<=n,m<=2e5),表示树的点数和询问树,接着n-1行,每行三个数u,v,w表示树上u到v节点有一条权值为w的边,接下来一行m个询问,每个询问要求输出树上两个节
点 ( u , v ) ( u < v ) (u,v)(u<v) (u,v)(u<v)之间最短距离中最长的一条边的小于 m i m_i mi的对数
思路:
先将询问按照从小到大排序,离线处理,小于第一个询问的肯定也小于第二个询问,接着将所有的边按权值排序,对于每次查询的值 m i m_i mi,将边权值小于等于 m i m_i mi的点用并查集连在一起
那么联通块中的所有点都是满足条件的,通过点数计算贡献,最后计算即可
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <cstring>
#include <cstdio>
#include <vector>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#define dll(x) scanf("%I64d",&x)
#define xll(x) printf("%I64d\n",x)
#define pii pair<int,int>
#define pll pair<long long ,long long>
#define gb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define mem(X) memset((X), 0, sizeof((X)))
#define memc(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
#define gg(x) getInt(&x)
using namespace std;
typedef long long ll;
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
ll powmod(ll a, ll b, ll MOD) {ll ans = 1; while (b) {if (b % 2)ans = ans * a % MOD; a = a * a % MOD; b /= 2;} return ans;}
inline void getInt(int* p);
const int inf = 0x3f3f3f3f;
const int N = 2e5 + 10;
ll ans, sum[N], res[N];
int n, m, pre[N];
struct node {
int u, v, w;
} e[N];
struct query {
int x, idx;
} q[N];
ll cal(ll x) {return x * (x - 1) / 2;}
int find(int x) {return pre[x] == x ? x : pre[x] = find(pre[x]);}
void add(int a, int b)
{
int fa = find(a);
int fb = find(b);
if (sum[fa] > sum[fb])swap(fa, fb);
pre[fa] = fb;
ans -= cal(sum[fa]); ans -= cal(sum[fb]);
sum[fb] += sum[fa];
ans += cal(sum[fb]);
}
int main(int argc, char const *argv[])
{
gb;
#ifdef ONLINE_JUDGE
#else
freopen("D:\\test\\in.txt", "r+", stdin);
#endif
cin >> n >> m;
for (int i = 0; i <= n; i++)pre[i] = i, sum[i] = 1;
for (int i = 1; i < n; i++) {
cin >> e[i].u >> e[i].v >> e[i].w;
}
sort(e + 1, e + n, [](node a, node b) {return a.w < b.w;});
for (int i = 1; i <= m; i++) {
cin >> q[i].x;
q[i].idx = i;
}
sort(q + 1, q + 1 + m, [](query a, query b) {return a.x < b.x;});
int now = 1;
for (int i = 1; i <= m; i++) {
int tmp = q[i].x;
while (now < n && e[now].w <= tmp) {
add(e[now].u, e[now].v);
now++;
}
res[q[i].idx] = ans;
}
for (int i = 1; i <= m; i++) {
cout << res[i] << " ";
}
return 0;
}