在Mindspore中实现Momentum Queue

在pytorch中构建一个momentum queue一般使用torch.nn.Module.register_buffer函数,但是mindspore中没有类似的注册方法,所以只能新建一个Parameter来保存。对比如下:

# pytorch
self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
# mindspore
randn = np.random.randn
image_queue = Parameter(normalize(Tensor(randn(self.K, config.hash_bit), mstype.float32)), 'image_queue', requires_grad=False)

在更新部分,mindspore框架的GRAPH模式不能直接对参数的slice赋值,因此只能使用矩阵乘法实现:

batch_size = image_feats.shape[0]

ptr = int(self.ptr_queue)
assert self.queue_size % batch_size == 0  # for simplicity

self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
ptr = (ptr + batch_size) % self.queue_size # move pointer

self.ptr_queue[0] = ptr  
# mindspore
# 在模型定义部分写
self.slide = Tensor(np.arange(0, self.K, 1), mstype.int32)

# 在更新方法里写
keys = ops.stop_gradient(k)
batch_size = keys.shape[0]

assert self.K % batch_size == 0
slide = self.slide
mask = logical_and((slide >= self.queue_ptr), (slide < (self.queue_ptr + batch_size)))
slide = cast(nonzero(mask * (slide + 1)).squeeze(), mstype.int32)
scatter_update(self.queue, slide, keys)
assign(self.queue_ptr, (self.queue_ptr + batch_size) % self.K)

实际上就是需要先构建一个选中要替换部分的mask,然后使用mask构建一个slide,最终根据这个slide指定的位置更新queue。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值