>Link
ybtoj最短距离
>Description
n
≤
1
0
5
,
m
≤
2
∗
1
0
5
n\le 10^5,m\le 2*10^5
n≤105,m≤2∗105
>解题思路
得出一些性质:一个白点到距离它最近的黑点,路上其它点都是白点,而且这个黑点也是这些白点距离最近的黑点
跑出来大概就是一个多条链拼在一起,结束点为黑点的图。发现这样很不好处理
我们建一个点
T
T
T,所有黑点与
T
T
T 连一条边权为 0 的边,这样最终得到的图就是一棵以
T
T
T 为根的树(所有白点最终到达
T
T
T)
从
T
T
T 出发跑单源最短路,找出每个点的最短路线,
T
T
T 到某个白点的最短路径一定是这个白点距离它最近的黑点的距离,最后用这些经过的边做一个最小生成树就行了
根据上面的性质,最终得出的图一定是树,因为一个白点“选定”了它到达的黑点,它只会从一条链过去,不会浪费去建其它路径
具体实现的话就跑最短路的时候,记录这个点从哪些点过来路程最短,最后把这些边都放过去排序
>代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
#include <map>
#define N 600010
#define LL long long
using namespace std;
struct node
{
int fr; LL w;
};
vector<node> f[N];
queue<int> Q;
map<int, bool> hav[N];
struct edge
{
int to, nxt; LL w;
} e[N];
struct EDGE
{
int x, y; LL w;
} a[N];
int n, m, cnt, h[N], col[N], fa[N], tot;
bool vis[N];
LL ans, dis[N];
void add (int u, int v, LL w)
{
e[++cnt] = (edge){v, h[u], w}; h[u] = cnt;
e[++cnt] = (edge){u, h[v], w}; h[v] = cnt;
}
bool cmp (EDGE aa, EDGE bb) {return aa.w < bb.w;}
int find (int now)
{
if (now == fa[now]) return now;
return fa[now] = find (fa[now]);
}
void spfa ()
{
memset (dis, 0x7f, sizeof (dis));
Q.push (n + 1), vis[n + 1] = 1, dis[n + 1] = 0;
int u, v, len;
while (!Q.empty())
{
u = Q.front();
Q.pop(); vis[u] = 0;
for (int i = h[u]; i; i = e[i].nxt)
{
v = e[i].to;
if (dis[u] + e[i].w == dis[v])
{
if (!hav[v][u]) //判断是否重复记录了
{
f[v].push_back ((node){u, e[i].w});
hav[v][u] = 1;
}
}
else if (dis[u] + e[i].w < dis[v])
{
len = f[v].size();
for (int j = 0; j < len; j++)
hav[v][f[v][j].fr] = 0;
f[v].clear();
f[v].push_back ((node){u, e[i].w});
hav[v][u] = 1;
dis[v] = dis[u] + e[i].w;
if (!vis[v])
Q.push (v), vis[v] = 1;
}
}
}
for (int i = 1; i <= n; i++)
{
len = f[i].size();
for (int j = 0; j < len; j++)
a[++tot] = (EDGE){i, f[i][j].fr, f[i][j].w};
}
}
int main()
{
freopen ("minimum.in", "r", stdin);
freopen ("minimum.out", "w", stdout);
scanf ("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf ("%d", &col[i]);
int u, v; LL w;
for (int i = 1; i <= m; i++)
{
scanf ("%d%d%lld", &u, &v, &w);
add (u, v, w);
}
for (int i = 1; i <= n; i++)
if (col[i]) add (i, n + 1, 0);
spfa ();
sort (a + 1, a + 1 + tot, cmp);
for (int i = 1; i <= n + 1; i++) fa[i] = i;
int fx, fy;
for (int i = 1; i <= tot; i++)
{
fx = find (a[i].x), fy = find (a[i].y);
if (fx == fy) continue;
fa[fx] = fy;
ans += a[i].w;
}
for (int i = 1; i <= n; i++)
if (find (i) != find (n + 1))
{
printf ("impossible");
return 0;
}
printf ("%lld", ans);
return 0;
}