最近在复现YOLACT++的时候有一个需求,就是要求MXNet支持动态的indexing,简言之就是用tensor A作为index去取tensor B的值,A是会变化的,但是A中值为1的数量是固定的。比如:
之所以有这个需求是因为,YOLACT++在forward的时候,选择了100个postive anchor对应的mask coefficients,去给对应的mask proto进行加权,然后去算mask loss。而对于不同的图片,这100个postive anchor在feature map上的位置是一直在变化的,所以选取的mask coefficients也是在变化的(其实就是一种动态的indexing)。YOLACT++本身是用PyTorch实现的,动态indexing对于动态图是很简单的(类似numpy)。
看起来非常简单的一个功能,但是目前MXNet的symbol API(静态图)是不支持的。。这里需要强调的是,虽然tensor A会变化,但是其中值为1的数量是固定的,因此用A去给B做indexing,得到的tensor shape是固定且已知的,所以这个并不违背静态图的要求。
当然,如果使用MXNet的Gluon API(动态图),这个是可以很简单做到的,就和PyTorch一样。但我们用Gluon写的网络,最后出于速度上的考虑,都会hybridize成静态图去跑,所以归根结底,还是要symbol模式下支持这种动态indexing才行。
反复看了官方的文档,发现MXNet symbol下面的take和pick看起来是“唯二”接近于实现这种功能的了,可惜他们接收的index都是一个固定参数,而非tensor,所以不存在动态变化的可能。
mxnet.symbol.take(a=None, indices=None, axis=_Null, mode=_Null, name=None, attr=None, out=None, **kwargs)
"Takes elements from an input array along the given axis. This function slices the input array along a particular axis with the provided indices."
...
mxnet.symbol.pick(data=None, index=None, axis=_Null, keepdims=_Null, mode=_Null, name=None, attr=None, out=None, **kwargs)
"Picks elements from an input array according to the input indices along the given axis."
既然没有直接实现此功能的operator,那么有什么办法可以间接实现吗?目前我想到的一种办法,通过修改原有的topk operator加上简单的变换,可以实现此功能。
怎么做呢?
比如对于tensor A和tensor B,用A对B进行indexing,首先可以已知A里面值为1的个数(假设是100),那么,首先可以求取B的最大值与最小值的gap,得到shape为(1,)的tensor:
max_gap_value = mx.sym.max(mask_data) - mx.sym.min(mask_data) + 1 # (1,)
然后,将这个gap value乘(broadcast multiply)到tensor A上(这样一样来,A中为1的位置乘以了gap value,而0的位置仍然为0),得到max_gap_arr.
max_gap_arr = A * max_gap_value
将max_gap_arr加到tensor B上,得到 B_add_gap
B_add_gap = B + max_gap_arr
再对B_add_gap取topk(令k=100),就可以把tensor A中值为1对应位置的B取出来了。
B_add_gap_pick = F.topk(B_add_gap, k=100, axis=-1, do_sort=0)
由于得到的B_add_gap_pick是加了max_gap_value的,所以需要再减去,得到原始的值。
proto_coef = B_add_gap_pick - max_gap_value
经过上面一番操作,就可以间接实现用tensor A给tensor B动态做indexing了。OK,曲线救国完毕?
然而,事情并没有那么简单。MXNet默认的topk operator有一个很“讨厌”的性质,会让上面的美梦化为泡影。
回到topk这个op上来,topk顾名思义,是用来取大小排在前k个元素的。
比如对于下面的tensor
我们令k=3,取topk(默认情况下,topk是针对axis=-1的,此例即“列”),可以得到
仔细看会发现,topk得到的结果,默认会把每行前k大的元素从大到小输出(相当于把这个k个元素做了一个sorting),这样似乎并无不妥。然而,对于上面的“曲线救国”方法,topk(倒数第二步)默认的sorting会把indexing的信息全丢了!
换言之,topk经过sort输出之后,我们无从得知output的每个元素到底对应着tensor A的哪个index了。要知道,tensor A不仅仅要对B进行indexing(得到B_out),还要对C做indexing(得到C_out),而B_out和C_out各个元素是有对应关系的,这个对应关系正是通过tensor A来维系的。而上面的曲线救国策略,topk输出时一旦sort了,这个对应关系就彻底乱掉了。
说了这么多,其实就是想说,需要让MXNet的topk在输出时不做sort。对于上面的tensor,执行topk且不做sort,我们期望的结果是:
想要实现这个功能,只需要添加一个do_sort的参数选项(默认为True),并修改topk的源码使得do_sort为False时,输出数值按照index的顺序排列,而不是数值大小降序排列。。这样一来,B_out和C_out的对应关系就得以保留了。
我修改后的一个版本:topk with do_sort param