1. 问题描述:
给定一个由 n 个点和 m 条边组成的无向连通加权图。设点 1 到点 i 的最短路径长度为 di。现在,你需要删掉图中的一些边,使得图中最多保留 k 条边。如果在删边操作全部完成后,点 1 到点 i 的最短路径长度仍为 di,则称点 i 是一个优秀点。你的目标是通过合理进行删边操作,使得优秀点的数量尽可能大。
输入格式
第一行包含三个整数 n,m,k。接下来 m 行,每行包含三个整数 x,y,w,表示点 x 和点 y 之间存在一条长度为 w 的边。保证给定无向连通图无重边和自环。
输出格式
第一行包含一个整数 e,表示保留的边的数量 (0 ≤ e ≤ k)。第二行包含 e 个不同的 1∼m 之间的整数,表示所保留的边的编号。按输入顺序,所有边的编号从 1 到 m。你提供的方案,应使得优秀点的数量尽可能大。如果答案不唯一,则输出任意满足条件的合理方案均可。
数据范围
对于前五个测试点,2 ≤ n ≤ 15,1 ≤ m ≤ 15。
对于全部测试点,2 ≤ n ≤ 10 ^ 5,1 ≤ m ≤ 10 ^ 5,n − 1 ≤ m,0 ≤ k ≤ m,1 ≤ x,y ≤ n,x ≠ y,1 ≤ w ≤ 10 ^ 9。
输入样例1:
3 3 2
1 2 1
3 2 1
1 3 3
输出样例1:
2
1 2
输入样例2:
4 5 2
4 1 8
2 4 1
2 1 3
3 4 9
3 1 5
输出样例2:
2
3 2
来源:https://www.acwing.com/problem/content/description/3631/
2. 思路分析:
分析题目可以知道我们去掉的边肯定不是到达某个点的最短路径中的边,所以保留最短路径上的边所有点的最短路径是不变的,因为边权大于0所以我们可以使用堆优化版的dijkstra算法求解出从1号点到其余点的最短路径,如何判断出当前的边是否是最短路径上的边呢?当我们求解出从起点到其余点的最短距离之后,遍历每一条边,如果发现dis[j] = dis[i] + w,说明节点i是由节点j更新过来的,所以属于最短路径上的边;当我们求解出了从1号点到其余点的最短距离之后那么可以使用dfs遍历所有的边,找到最短路径上的k条边即可。
3. 代码如下:
import heapq
from typing import List
class Solution:
# 堆优化版的dijkstra算法求解从1号点到其余点的最短距离
def dijkstra(self, n: int, dis: List[int], g: List[List[int]]):
vis = [0] * (n + 10)
q = list()
heapq.heappush(q, (0, 1))
dis[1] = 0
while q:
p = heapq.heappop(q)
if vis[p[1]] == 1: continue
vis[p[1]] = 1
for next in g[p[1]]:
if dis[next[0]] > dis[p[1]] + next[1]:
dis[next[0]] = dis[p[1]] + next[1]
heapq.heappush(q, (dis[next[0]], next[0]))
def dfs(self, u: int, k: int, vis: List[int], dis: List[int], g: List[List[int]], res: List[int]):
vis[u] = 1
for next in g[u]:
# 当前点next[0]是由u更新过来的并且next[0]还未访问过
if vis[next[0]] == 0 and dis[next[0]] == dis[u] + next[1]:
if len(res) < k:
# next[2]为当前边的编号
res.append(next[2])
self.dfs(next[0], k, vis, dis, g, res)
def process(self):
n, m, k = map(int, input().split())
g = [list() for i in range(n + 10)]
for i in range(1, m + 1):
x, y, z = map(int, input().split())
# 因为最终需要输出边的编号需要存储多一个边的编号的信息
g[x].append((y, z, i))
g[y].append((x, z, i))
INF = 10 ** 18
dis = [INF] * (n + 10)
self.dijkstra(n, dis, g)
res = list()
# 使用vis列表标记dfs的过程中哪些点已经被访问了, 每一个点只能够被访问一次
vis = [0] * (n + 10)
self.dfs(1, k, vis, dis, g, res)
print(len(res))
for x in res:
print(x, end=" ")
if __name__ == '__main__':
Solution().process()