// Approach: union-find (disjoint-set union) with offline query processing.
// Code:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
// Capacity limits for the fixed-size global arrays. The +2 slack over round
// numbers presumably covers the problem bounds (m <= 100000 edges,
// n <= 20000 nodes, q <= 5000 queries) -- TODO confirm against the statement.
// Typed constants instead of #define so the names obey scope and type rules.
const int maxm = 100002;  // capacity of the edge arrays
const int maxn = 20002;   // capacity of the union-find arrays (node ids must be < maxn)
const int maxq = 5002;    // capacity of the query array
int set[maxn];   // union-find parent links; set[x] == x marks a root
int size[maxn];  // element count of the set rooted at x (meaningful for roots only)
// One offline query: its threshold, its position in the input (so answers
// can be restored to input order after sorting), and the computed answer.
struct que
{
    int q;          // query threshold
    int i;          // index of the query in the input
    long long ans;  // answer for this query
};
que ques[maxq];
// Orders queries by ascending threshold q.
bool cmpp1(que a, que b)
{
    return b.q > a.q;
}
// Orders queries by their original input index i.
bool cmpp2(que a, que b)
{
    return b.i > a.i;
}
// An undirected weighted edge. init() stores the endpoints normalized so
// that i <= j, making (i, j) a canonical key for the node pair.
struct edge
{
    int v, i, j;  // weight, smaller endpoint, larger endpoint

    void init(int a, int b, int x)
    {
        const bool ordered = a < b;
        i = ordered ? a : b;
        j = ordered ? b : a;
        v = x;
    }
};
edge a1[maxm];  // edges as read from the input
edge a2[maxm];  // second edge buffer of the same capacity
// Lexicographic order on the normalized endpoint pair (i, j), so edges that
// join the same two nodes sort next to each other.
bool cmp1(edge a, edge b)
{
    if (a.i != b.i)
        return a.i < b.i;
    return a.j < b.j;
}
// Orders edges by ascending weight v.
bool cmp2(edge a, edge b)
{
    return b.v > a.v;
}
// Binary search over A[x, y), assumed sorted by ascending weight v.
// Returns the first index whose weight is strictly greater than v, or y if
// there is none (the same contract as std::upper_bound keyed on edge::v).
int upperbound(edge *A, int x, int y, int v)
{
    int lo = x;
    int hi = y;
    // Invariant (for sorted input): A[x, lo) has weight <= v and
    // A[hi, y) has weight > v.
    while (hi > lo)
    {
        const int mid = lo + (hi - lo) / 2;
        if (v < A[mid].v)
            hi = mid;
        else
            lo = mid + 1;
    }
    return lo;
}
void init_set()
{
for (int i = 0; i<maxn; ++i)
{
set[i] = i;
size[i] = 1;
}
}
int findSet(int x)//路径压缩
{
if (x == set[x])
return x;
else
return set[x] = findSet(set[x]);
}
void unionSet(int x, int y)//启发式合并
{
int fx = findSet(x);
int fy = findSet(y);
if (fy == fx)
return;
if (size[fx] >= size[fy])
{
size[fx] += size[fy];
set[fy] = fx;
}
else
{
size[fy] += size[fx];
set[fx] = fy;
}
}
int main()
{
//freopen("input.txt", "r", stdin);
int T;
scanf("%d", &T);
while (T--)
{
int n, m, q, a, b, c;
long long ans = 0, t;
scanf("%d%d%d", &n, &m, &q);
for (int i = 0; i < m; ++i)
{
scanf("%d%d%d", &a, &b, &c);
a1[i].init(a, b, c);
}
sort(a1, a1 + m, cmp1);
a = 0, b = 0;
int m2 = 0;
for (int i = 0; i < m; ++i)
{
if (a1[i].i == a&&a1[i].j == b)
a2[m2 - 1].v = min(a1[i].v, a2[m2 - 1].v);
else
{
a2[m2++] = a1[i];
a = a1[i].i, b = a1[i].j;
}
}
sort(a2, a2 + m2, cmp2);
for (int i = 0; i < q; ++i)
{
scanf("%d", &ques[i].q);
ques[i].i = i;
ques[i].ans = 0;
}
sort(ques, ques + q, cmpp1);
init_set(); a = 0;
for (int i = 0; i < q; ++i)
{
b = upperbound(a2, 0, m2, ques[i].q);
//printf("b %d\n", b);
for (int j = a; j < b; ++j)
{
unionSet(a2[j].i, a2[j].j);
}
for(int j = 1; j <= n; ++j)
{
if (set[j] == j&&size[j] > 1)
{
t = size[j];
ques[i].ans += t*(t - 1);
//printf("t %lld\n", t);
//printf("ans %lld\n", ques[i].ans);
}
}
a = b;
}
sort(ques, ques + q, cmpp2);
for (int i = 0; i < q; ++i)
printf("%lld\n", ques[i].ans);
}
//system("pause");
//while (1);
return 0;
}