The implementation of Prim's algorithm is very similar to that of Dijkstra's algorithm:
Prim:
// Prim's minimum-spanning-tree algorithm on a dense cost matrix, grown from
// vertex 0.
//
// cost : n x n matrix; cost[u][v] is the weight of edge (u, v) and INF marks a
//        missing edge.
// path : output, resized to n. path[v] is the tree vertex that v was attached
//        to (its parent in the spanning tree). path[0] is -1 because vertex 0
//        is the root; vertices that cannot be reached from vertex 0 also keep
//        -1.
//
// For both Prim and Dijkstra the algorithm repeatedly picks the unmarked
// vertex with the smallest lowcost value. In Prim, lowcost[i] is the cheapest
// edge from the already-marked set to i; in Dijkstra it would be the distance
// from the source to i.
void Prim(const vector<vector<int>>& cost, vector<int>& path) {
    const int n = static_cast<int>(cost.size());
    path.assign(n, -1);
    if (n == 0)
        return;

    vector<int> vis(n, 0), lowcost(n, INF);
    // parent[j] is the marked vertex whose edge currently gives lowcost[j].
    // It is only committed to path[] when j is actually added to the tree.
    vector<int> parent(n, 0);
    for (int i = 0; i < n; ++i)
        lowcost[i] = cost[0][i];
    vis[0] = 1;

    for (int i = 1; i < n; ++i) {
        int minc = INF, p = -1;
        for (int j = 0; j < n; ++j) {
            if (0 == vis[j] && lowcost[j] < minc) {
                minc = lowcost[j];
                p = j;
            }
        }
        // No unmarked vertex is reachable: the graph is disconnected.
        if (p == -1)
            return;

        vis[p] = 1;
        path[p] = parent[p];

        // Adding p may offer a cheaper edge to the remaining vertices.
        for (int j = 0; j < n; ++j) {
            if (0 == vis[j] && lowcost[j] > cost[p][j]) {
                lowcost[j] = cost[p][j];
                parent[j] = p;
            }
        }
    }
}
在Prim中,lowercost[i]表示已经marked的点集到结点i的最短距离;在Dijkstra中,lowercost[i]表示起点到结点i的最短距离。因此这两个算法可以写成类似的形式:都先从lowercost里挑最小的,区别只在于刷新lowercost的方式不一样。当然Prim也可以用堆来实现,写法上变化会大一些,后面LeetCode的例子很好。
Python Version:
inf = 0x3ffffff  # sentinel cost meaning "no edge between these two nodes"


def Prim(costs):
    """Grow a minimum spanning tree from node 0 with Prim's algorithm.

    Args:
        costs: Square adjacency matrix; ``costs[u][v]`` is the weight of the
            edge between ``u`` and ``v``, with ``inf`` marking a missing edge.

    Returns:
        A list ``path`` of length ``len(costs)`` in which ``path[k]`` is the
        index of the k-th node added to the tree (``path[0]`` is the start
        node 0). Entries stay 0 when the matrix is empty or its first row has
        the wrong length, and for the positions that could not be filled
        because the graph is disconnected.

    Side effects:
        Prints one line per added node with its 1-based number and edge cost.
    """
    nodeSize = len(costs)
    path = [0] * nodeSize
    if nodeSize == 0 or len(costs[0]) != nodeSize:
        return path

    vis = [False] * nodeSize
    vis[0] = True
    # lowercost[j]: cheapest edge from the current tree to node j.
    lowercost = list(costs[0])

    for i in range(1, nodeSize):
        # Pick the unvisited node closest to the tree (lowest index on ties).
        minc, minn = inf, -1
        for j in range(1, nodeSize):
            if not vis[j] and lowercost[j] < minc:
                minc, minn = lowercost[j], j
        if minn < 0:
            # Nothing else is reachable; further scans cannot find anything.
            break

        vis[minn] = True
        path[i] = minn
        print('add node ' + str(minn + 1) + ' cost: ' + str(lowercost[minn]))

        # The new tree node may offer cheaper edges to the remaining nodes.
        for j in range(nodeSize):
            if not vis[j] and costs[minn][j] < lowercost[j]:
                lowercost[j] = costs[minn][j]
    return path
def Dijstrala(costs, start, end):
    """Compute single-source shortest paths with Dijkstra's algorithm.

    Args:
        costs: Square adjacency matrix; ``costs[u][v]`` is the weight of the
            edge from ``u`` to ``v``, with the module-level ``inf`` marking a
            missing edge.
        start: Index of the source node.
        end: Unused; kept so existing callers keep working. Distances to every
            node are always computed.

    Returns:
        A tuple ``(path, lowercost)``. ``path[v]`` is the predecessor of ``v``
        on a shortest path from ``start`` (-1 for ``start`` itself and for
        nodes that cannot be reached). ``lowercost[v]`` is the shortest
        distance found from ``start`` to ``v`` (``inf`` if unreachable).
        For an empty matrix, or one whose first row has the wrong length, the
        result is ``([-1] * len(costs), [])``.
    """
    nodeSize = len(costs)
    if nodeSize == 0 or len(costs[0]) != nodeSize:
        return [-1] * nodeSize, []

    lowercost = list(costs[start])
    # Nodes with a direct edge from the source start out with it as predecessor.
    path = [start if cost < inf else -1 for cost in lowercost]
    path[start] = -1
    vis = [False] * nodeSize
    vis[start] = True

    for _ in range(1, nodeSize):
        # Settle the closest node that is still unvisited and reachable.
        minc, minn = inf, -1
        for j in range(nodeSize):
            if not vis[j] and lowercost[j] < minc:
                minc, minn = lowercost[j], j
        if minn < 0:
            # Everything left is unreachable; indexing with -1 would corrupt
            # the entry of the last node.
            break
        vis[minn] = True

        # Relax edges leaving the settled node, recording it as predecessor of
        # every node whose distance improves.
        for j in range(nodeSize):
            candidate = minc + costs[minn][j]
            if not vis[j] and candidate < lowercost[j]:
                lowercost[j] = candidate
                path[j] = minn
    return path, lowercost
if __name__ == '__main__':
    # Demo on a 6-node undirected graph; `inf` marks a missing edge.
    costs = [
        [0, 6, 1, 5, inf, inf],
        [6, 0, 5, inf, 3, inf],
        [1, 5, 0, 5, 6, 4],
        [5, inf, 5, 0, inf, 2],
        [inf, 3, 6, inf, 0, 6],
        [inf, inf, 4, 2, 6, 0],
    ]

    # Order in which Prim adds the nodes to the spanning tree.
    print(Prim(costs))

    # Shortest paths from node 0: first the path list, then the distances.
    predecessors, distances = Dijstrala(costs, 0, 5)
    print(predecessors)
    print(distances)
-------------------------------------------------------------------------------------------
There are N cities numbered from 1 to N.
You are given connections, where each connections[i] = [city1, city2, cost] represents the cost to connect city1 and city2 together. (A connection is bidirectional: connecting city1 and city2 is the same as connecting city2 and city1.)
Return the minimum cost so that for every pair of cities, there exists a path of connections (possibly of length 1) that connects those two cities together. The cost is the sum of the connection costs used. If the task is impossible, return -1.
Example 1:
Input: N = 3, connections = [[1,2,5],[1,3,6],[2,3,1]] Output: 6 Explanation: Choosing any 2 edges will connect all cities so we choose the minimum 2.
Example 2:
Input: N = 4, connections = [[1,2,3],[3,4,4]] Output: -1 Explanation: There is no way to connect all cities even if all edges are used.
Note:
1 <= N <= 10000
1 <= connections.length <= 10000
1 <= connections[i][0], connections[i][1] <= N
0 <= connections[i][2] <= 10^5
connections[i][0] != connections[i][1]
Prim解法:
import collections
import heapq
from typing import List
class Solution:
    """Minimum cost to connect all cities, using Prim's algorithm with a heap."""

    def minimumCost(self, N: int, connections: List[List[int]]) -> int:
        """Return the total cost of a minimum spanning tree over cities 1..N.

        Args:
            N: Number of cities, labelled 1 to N.
            connections: ``[city1, city2, cost]`` triples describing the
                available bidirectional connections.

        Returns:
            The sum of the chosen connection costs, or -1 when the cities
            cannot all be connected.
        """
        # A spanning tree needs at least N - 1 edges.
        if len(connections) < N - 1:
            return -1
        if N == 1:
            return 0

        adjacency = collections.defaultdict(list)
        for city_a, city_b, cost in connections:
            adjacency[city_a].append((cost, city_b))
            adjacency[city_b].append((cost, city_a))

        visited = [False] * (N + 1)
        # Heap of (edge cost, city); city 1 is the seed and costs nothing.
        frontier = [(0, 1)]
        total = 0
        connected = 0
        while frontier and connected < N:
            cost, city = heapq.heappop(frontier)
            if visited[city]:
                # Stale entry: the city was reached earlier by a cheaper edge.
                continue
            visited[city] = True
            total += cost
            connected += 1
            # The heap orders the candidates, so the adjacency lists need no
            # pre-sorting; edges into the tree can never be selected.
            for next_cost, next_city in adjacency[city]:
                if not visited[next_city]:
                    heapq.heappush(frontier, (next_cost, next_city))

        return total if connected == N else -1
Kruskal解法:
class Solution:
    """Minimum cost to connect all cities, using Kruskal's algorithm."""

    def minimumCost(self, N: int, connections: List[List[int]]) -> int:
        """Return the total cost of a minimum spanning tree over cities 1..N.

        Args:
            N: Number of cities, labelled 1 to N.
            connections: ``[city1, city2, cost]`` triples describing the
                available bidirectional connections.

        Returns:
            The sum of the chosen connection costs, or -1 when the cities
            cannot all be connected.
        """
        # A spanning tree needs at least N - 1 edges.
        if len(connections) < N - 1:
            return -1
        if N == 1:
            return 0

        parent = list(range(N + 1))
        size = [1] * (N + 1)

        def find(x):
            # Iterative on purpose: a recursive find can exceed Python's
            # recursion limit on long parent chains (N may be 10000).
            root = x
            while parent[root] != root:
                root = parent[root]
            # Path compression: point every node on the walk at the root.
            while parent[x] != root:
                following = parent[x]
                parent[x] = root
                x = following
            return root

        total = 0
        edges_used = 0
        for city_a, city_b, cost in sorted(connections, key=lambda edge: edge[2]):
            root_a, root_b = find(city_a), find(city_b)
            if root_a == root_b:
                continue
            # Union by size keeps the trees shallow.
            if size[root_a] < size[root_b]:
                root_a, root_b = root_b, root_a
            parent[root_b] = root_a
            size[root_a] += size[root_b]
            total += cost
            edges_used += 1
            if edges_used == N - 1:
                break

        # Exactly N - 1 accepted edges means every city is in one component.
        return total if edges_used == N - 1 else -1