给出一个图, 其中一些特殊点只能连一条边, 问最小生成树的权值.
我们可以发现, 生成树上一个点如果只能连一条边, 那么他肯定是叶子结点.
有两个思路:
用非特殊点跑最小生成树, 最后再把特殊点接上就行了. 这是赛后听到的思路.
直接跑prim, 但是特殊点不用来更新其他点的距离(dis).
理论上第二种也挺简单的, 然而赛中疯狂wa…因为判断因为是特殊点不再更新和因为是走过不再更新两个的顺序出了问题. 本质上还是对prim的理解不够深.
首先我们回想一个问题: 为什么prim要在队列过程中计算答案.
while (!q.empty()) {
qnode t = q.top();
int u=t.pos;
q.pop();
if (vis[u] == 1)
continue;
vis[u] = 1;
ans+=t.ds;// 为什么要在这个位置计算答案? 在最后计算不行吗?
for (int i = head[u]; i != -1; i = edge[i].nxt) {
Edge&e = edge[i];
if (dis[e.v]>e.w){ //Dijkstra只是这里不同
dis[e.v] = e.w;
q.push(qnode{ e.v,dis[e.v] });
}
}
}
return ans;
实际上在最后dis求和这样是不行的. 因为在第一次到x点的时候, dis[x]是到达这个点最短边权的长度, 然而在多次到达这个点的时候, 这个值会不断的被更新, 到后面时, dis数组跟本不能组成那个所求的树.
从这个角度, 我们可以发现dis是记录当前把这个点取到生成树中的代价.
回到这个题, 如果我们认为这个点是特殊点的话, 那么就不从这个点向下更新. 代码大概是这样的:
while (!q.empty()) {
qnode t = q.top();
int u = t.pos;
q.pop();
if (vis[u] == 1)continue;
vis[u] = 1;
ss.insert(u);
ans += t.ds;
if(p[u]==1)continue;
for (int i = head[u]; i != -1; i = edge[i].nxt) {
Edge e = edge[i];
if (dis[e.v] > e.w) {
dis[e.v] = e.w;
qnode tmp1;
tmp1.pos = e.v, tmp1.ds = dis[e.v];
q.push(tmp1);
}
}
}
但这个方式会有问题, 像这样的图:
就会出问题. 因为对于特殊点的操作是阉割版的, 所以虽然对于特殊点后面的边取法不会出问题, 但是对于取这个点的边并不是最优的. 取了连到特殊点x的第一条边d1, 就把x的vis设为了1, 从此就认为x已经是最优的了.
所以只需要到特殊点的话, 特判更新其边的最优贡献就好了.
struct Edge {int u, v, nxt;ll w;} edge[M];
int head[N], ei;
struct qnode {
int pos;
ll ds;
bool operator<(const qnode a) const {
return ds > a.ds;
}
};
int n, m, P;
ll dis[N];
int vis[N], p[N];
void addedge(int u, int v, ll w) {
edge[ei].u = u, edge[ei].v = v, edge[ei].w = w;
edge[ei].nxt = head[u];
head[u] = ei++;
}
void init() {
memset(head, -1, sizeof(head));
}
int pre[M];
ll prim(int s) {
ll ans = 0;
for (int i = 1; i <= n; i++) {
dis[i] = INF;
vis[i] = 0;
}
dis[s] = 0;
priority_queue<qnode> q;
set<int> ss;
fill(pre,pre+M,inf);
qnode tmp;
tmp.pos = s, tmp.ds = dis[s];
q.push(tmp);
while (!q.empty()) {
qnode t = q.top();
int u = t.pos;
q.pop();
if (p[u] == 1) {
if(pre[u]==inf)ss.insert(u);
ans-=(pre[u]==inf)?0:pre[u];
checkMin(pre[u],t.ds);
ans+=pre[u];
continue;
}
if (vis[u] == 1)continue;
vis[u] = 1;
ss.insert(u);
ans += t.ds;
for (int i = head[u]; i != -1; i = edge[i].nxt) {
Edge e = edge[i];
if (dis[e.v] > e.w) {
dis[e.v] = e.w;
qnode tmp1;
tmp1.pos = e.v, tmp1.ds = dis[e.v];
q.push(tmp1);
}
}
}
if (ss.size() != n)return INF;
return ans;
}
ll minn(ll a, ll b) {
if (a > b)return b;
return a;
}
int main() {
int cnt = 0;
ll WW = INF;
scanf("%d%d%d", &n, &m, &P);
for (int i = 1; i <= P; i++) {
int tmp;
scanf("%d", &tmp);
p[tmp] = 1;
}
init();
for (int i = 1; i <= m; i++) {
int u, v;
ll w;
scanf("%d%d%lld", &u, &v, &w);
if (u == v)continue;
addedge(u, v, w);
addedge(v, u, w);
if (u == v)cnt++;
if (u != v)WW = minn(WW, w);
}
int s = -1;
for (int i = 1; i <= n; i++) {
if (p[i] == 0) {
s = i;
break;
}
}
if (n == 2) {
printf("%lld\n", WW);
return 0;
}
if (s == -1) {
printf("-1\n");
return 0;
}
ll anss = prim(s);
printf("%lld\n", ((anss >= INF) ? -1 : anss));
return 0;
}
场内没想到可以有Ⅰ的写法: 将非特殊点单独跑最小生成树, 再把特殊点接上去:
int n, m, P;
ll dis[N];
int vis[N], p[N];
void addedge(int u, int v, ll w) {
edge[ei].u = u, edge[ei].v = v, edge[ei].w = w;
edge[ei].nxt = head[u];
head[u] = ei++;
}
void init() {
memset(head, -1, sizeof(head));
}
int pre[M];
ll prim(int s) {
ll ans = 0;
for (int i = 1; i <= n; i++) {
dis[i] = INF;
vis[i] = 0;
}
dis[s] = 0;
priority_queue<qnode> q;
set<int> ss;
fill(pre,pre+M,inf);
qnode tmp;
tmp.pos = s, tmp.ds = dis[s];
q.push(tmp);
while (!q.empty()) {
qnode t = q.top();
int u = t.pos;
q.pop();
if (p[u] == 1) {
if(pre[u]==inf)ss.insert(u);
ans-=(pre[u]==inf)?0:pre[u];
checkMin(pre[u],t.ds);
ans+=pre[u];
continue;
}
if (vis[u] == 1)continue;
vis[u] = 1;
ss.insert(u);
ans += t.ds;
for (int i = head[u]; i != -1; i = edge[i].nxt) {
Edge e = edge[i];
if (dis[e.v] > e.w) {
dis[e.v] = e.w;
qnode tmp1;
tmp1.pos = e.v, tmp1.ds = dis[e.v];
q.push(tmp1);
}
}
}
if (ss.size() != n)return INF;
ll ans2 = 0;
for (int i = 1; i <= n; i++) {
ans2 += dis[i];
}
return ans;
}
ll minn(ll a, ll b) {
if (a > b)return b;
return a;
}
int main() {
int cnt = 0;
ll WW = INF;
scanf("%d%d%d", &n, &m, &P);
for (int i = 1; i <= P; i++) {
int tmp;
scanf("%d", &tmp);
p[tmp] = 1;
}
init();
for (int i = 1; i <= m; i++) {
int u, v;
ll w;
scanf("%d%d%lld", &u, &v, &w);
if (u == v)continue;
addedge(u, v, w);
addedge(v, u, w);
if (u == v)cnt++;
if (u != v)WW = minn(WW, w);
}
int s = -1;
for (int i = 1; i <= n; i++) {
if (p[i] == 0) {
s = i;
break;
}
}
if (n == 2) {
printf("%lld\n", WW);
return 0;
}
if (s == -1) {
printf("-1\n");
return 0;
}
ll anss = prim(s);
printf("%lld\n", ((anss >= INF) ? -1 : anss));
return 0;
}