在pytorch中构建一个momentum queue一般使用torch.nn.Module.register_buffer函数,但是mindspore中没有类似的注册方法,所以只能新建一个Parameter来保存。对比如下:
# pytorch
self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
# mindspore
randn = np.random.randn
image_queue = Parameter(normalize(Tensor(randn(self.K, config.hash_bit), mstype.float32)), 'image_queue', requires_grad=False)
在更新部分,mindspore框架的GRAPH模式不能直接对参数的slice赋值,因此只能使用矩阵乘法实现:
batch_size = image_feats.shape[0]
ptr = int(self.ptr_queue)
assert self.queue_size % batch_size == 0 # for simplicity
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
ptr = (ptr + batch_size) % self.queue_size # move pointer
self.ptr_queue[0] = ptr
# mindspore
# 在模型定义部分写
self.slide = Tensor(np.arange(0, self.K, 1), mstype.int32)
# 在更新方法里写
keys = ops.stop_gradient(k)
batch_size = keys.shape[0]
assert self.K % batch_size == 0
slide = self.slide
mask = logical_and((slide >= self.queue_ptr), (slide < (self.queue_ptr + batch_size)))
slide = cast(nonzero(mask * (slide + 1)).squeeze(), mstype.int32)
scatter_update(self.queue, slide, keys)
assign(self.queue_ptr, (self.queue_ptr + batch_size) % self.K)
实际上就是需要先构建一个选中要替换部分的mask,然后使用mask构建一个slide,最终根据这个slide指定的位置更新queue。