这题主要是用BFS, 我一开始用的最短路径算法, TLE辽= =
BFS
不用说了,到处都是
class Solution(object):
def numBusesToDestination(self, routes, S, T):
"""
:type routes: List[List[int]]
:type S: int
:type T: int
:rtype: int
"""
if S==T:
return 0
from collections import defaultdict
stop2bus = defaultdict(set)
n_bus = len(routes)
for bus in xrange(n_bus):
for stop in routes[bus]:
stop2bus[stop].add(bus)
bfs = []
bfs += list(zip(stop2bus[S],itertools.repeat(1)))
mark = stop2bus[S]
while len(bfs):
bus, d = bfs.pop(0)
if bus in stop2bus[T]:
return d
buses = set()
for stop in routes[bus]:
buses |= stop2bus[stop]
buses.remove(bus)
bfs += list(zip(buses-mark, itertools.repeat(d+1)))
mark |= buses
return -1
双向BFS
BFS是从起点单向搜, 双向BFS是从终点和起点同时搜.
主要有两个坑:
- 如这篇博文示例的一样, 只有交替搜索才靠谱
交替搜索就是每次扩展数量少的队列, 扩展一次增加一个长度, 两个队列有交集的时候相遇, 这时候算出来的长度就是总长度 - 传统的方法: 一次取一个, 然后判断是否被对方走过. 只适用于bfs第一次只加一个起点和一个终点的情况
比如, 起点(2,1), (1,1), 终点(0,1),(1,1), 先遍历前一个队列, 假设扩展了一个点(0,2), 这时候不管是不是交替算法, 都要扩展反向的(0,1), 发现对方访问过了, 以为最优解是2-0就给返回了,其实只要遍历一个1
出现这种情况的根本原因是, 扩展只能发生在队列没有交集的时候, 即必须保证放入队列的节点没有被对方遍历过,
当起点和终点只有一个的时候, 扩展可以很容易地被while
下面的第一个判断语句终止, 但是当起点有多个的时候, 就不行了
解决的方法有三种:
- 加两个单个的虚拟节点
- while前判断一下, 加入新节点时判断一下
- 用set把每层节点包起来一起判断
参考讨论区@yorkshire 的代码, 可以用set来记录最外层:
class Solution(object):
def numBusesToDestination(self, routes, S, T):
"""
:type routes: List[List[int]]
:type S: int
:type T: int
:rtype: int
"""
if S==T:
return 0
from collections import defaultdict
stop2bus = defaultdict(set)
n_bus = len(routes)
for bus in xrange(n_bus):
for stop in routes[bus]:
stop2bus[stop].add(bus)
bfs_s, bfs_t = stop2bus[S], stop2bus[T]
visited = set()
d = 1
while len(bfs_s) and len(bfs_t):
if bfs_s & bfs_t:
return d
if len(bfs_s)>len(bfs_t):
bfs_s, bfs_t = bfs_t, bfs_s
d += 1
visited |= bfs_s
buses = set()
for b in bfs_s:
for s in routes[b]:
buses |= stop2bus[s]
bfs_s = buses-visited
return -1
也可以用传统的写法, 每次pop一个, 效率上没什么差别:
class Solution(object):
def numBusesToDestination(self, routes, S, T):
"""
:type routes: List[List[int]]
:type S: int
:type T: int
:rtype: int
"""
if S==T:
return 0
from collections import defaultdict
stop2bus = defaultdict(set)
n_bus = len(routes)
for bus in xrange(n_bus):
for stop in routes[bus]:
stop2bus[stop].add(bus)
bfs_s, bfs_t = list(stop2bus[S]), list(stop2bus[T])
mark_s, mark_t = dict(zip(stop2bus[S],itertools.repeat(1))), dict(zip(stop2bus[T],itertools.repeat(1)))
if set(bfs_s) & set(bfs_t):
return 1
while len(bfs_s) and len(bfs_t):
if len(bfs_s)>len(bfs_t):
bfs_s, bfs_t = bfs_t, bfs_s
mark_s, mark_t = mark_t, mark_s
bus = bfs_s.pop(0)
for s in routes[bus]:
for b in stop2bus[s]:
if b in mark_t:
return mark_s[bus]+mark_t[b]
if b not in mark_s:
bfs_s.append(b)
mark_s[b] = mark_s[bus]+1
return -1