在tensorflow中,tf.map_fn()高阶函数(high-level function),和在python中的高阶函数意义相似,其也是将函数当成参数传入,以实现一些有趣的,有用的操作,函数原型为
map_fn(
fn,
elems,
dtype=None,
parallel_iterations=10,
back_prop=True,
swap_memory=False,
infer_shape=True,
name=None
)
其中的fn是一个可调用的(callable)函数,就是我们图中的function,一般会使用lambda表达式表示。elems是需要做处理的Tensors,TF将会将elems从第一维展开,进行map处理。主要就是那么两个,其中dtype为可选项,但是比较重要,他表示的是fn函数的输出类型,如果fn返回的类型和elems中的不同,那么就必须显式指定为和fn返回类型相同的类型。下面给出一个retinanet的例子:
def _filter_detections(args):
boxes = args[0]
classification = args[1]
other = args[2]
return filter_detections(
boxes,
classification,
other,
nms &#