MindSpore 实现unflod和flod

本文写于2022年12月23日。因为MindSpore框架在不断更新,可能你看到这篇文章的时候已经不再适用,或者有更好的实现方式。

unflod的实现

unflod的实现比较简单,因为已经有nn的接口了。实现方法可以参考我的另一篇博文

MindSpore和Python中nn.Unfold的区别_失落的熊熊的博客-CSDN博客_python unfold

flod的实现

flod的实现就有点不那么容易了,因为还没有开发出接口。已经提了issue,并得到了回答,官方说大概要2022年Q1才能完成,并给了我另一个解决方法,使用Col2Im算子。而Col2lm算子的使用也是一个令人头大的事情。

希望添加flod算子 · Issue #I663W0 · MindSpore/mindspore - Gitee.com

官方的文档可以说写的真的非常不清楚,甚至给了我很大的误导。大家可以先看一下官方的文档。

mindspore.Tensor — MindSpore master documentation

mindspore.ops.col2im — MindSpore master documentation

根据我的理解col2im的input需要四个维度。分别是(bs,c,kernel_size*kernel_size,n)。其中bs是batch_size;c是输出的深度,也就是unflod之前的深度;kernel_size是核大小;n为滑动窗口的数量,这个值是计算得出的。

下面给大家我进行测试的一个示例代码,辅助大家对unflod和flod的理解。

import mindspore as ms
import numpy as np
from mindspore import nn
import mindspore.common.dtype as mstype

opReshape = ms.ops.Reshape()
unfold = nn.Unfold(ksizes=[1, 3, 3, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='valid')
x = ms.Tensor(input_data=np.random.rand(1, 2, 4, 4), dtype=mstype.float32) # 1,2,4,4
print(x)
x_unflod = unfold(x) # 1,18,2,2
x_unflod = opReshape(x_unflod,(1,2,9,-1))
[bs,c,h,w]=x.shape
output_size = ms.Tensor(input_data=[h, w], dtype=mstype.int32)
x_flod = ms.ops.col2im(x_unflod, output_size, kernel_size=[3, 3], dilation=[1, 1],
                       padding_value=[0, 0], stride=[1, 1])
print(x_flod)

而正如文章Pytorch unfold和fold_comea23的博客-CSDN博客_pytorch unfold

 所说,只有卷积不重叠的时候unflod和flod才是互逆的,如果想要得到互逆的结果,可以修改参数以实现。

opReshape = ms.ops.Reshape()
unfold = nn.Unfold(ksizes=[1, 3, 3, 1], strides=[1, 3, 3, 1], rates=[1, 1, 1, 1], padding='valid')
x = ms.Tensor(input_data=np.random.rand(1, 1, 6, 6), dtype=mstype.float32) # 1,2,4,4
print(x)
x_unflod = unfold(x) # 1,9,2,2

x_unflod = opReshape(x_unflod,(1,1,9,-1))

[bs,c,h,w]=x.shape
output_size = ms.Tensor(input_data=[h, w], dtype=mstype.int32)
x_flod = ms.ops.col2im(x_unflod, output_size, kernel_size=[3, 3], dilation=[1, 1],
                       padding_value=[0, 0], stride=[3, 3])
print(x_flod)

存在问题:

在使用中,我还是遇到了问题。我在反向传播时用了col2im,然后出现了错误。RuntimeError: Illegal primitive: Primitive Col2Im's bprop not defined.

def construct 中代码如下:

output_size = Tensor(input_data=[h, w], dtype=mstype.int32)
part_ref_rerang = part_ref_rerang_unflod.col2im(output_size, kernel_size=[3, 3], dilation=[1, 1],
padding_value=[1, 1], stride=[1, 1])


报错信息如下:
 

Traceback (most recent call last):
File "main.py", line 125, in
main()
File "main.py", line 108, in main
train_loop(model, train_dataset, loss_ce, optimizer, args) # ________
File "/tmp/pycharm_project_894/train_loop.py", line 80, in train_loop
loss, logits = train_step(data, label)
File "/tmp/pycharm_project_894/train_loop.py", line 70, in train_step
(loss, logits), grads = grad_fn(data, label)
File "/root/anaconda3/envs/mindspore/lib/python3.8/site-packages/mindspore/ops/functional.py", line 455, in inner_aux_grad_fn
return res, grad_weight(aux_fn, weights)(*args)
File "/root/anaconda3/envs/mindspore/lib/python3.8/site-packages/mindspore/ops/composite/base.py", line 530, in after_grad
return grad(fn, weights)(*args, **kwargs)
File "/root/anaconda3/envs/mindspore/lib/python3.8/site-packages/mindspore/common/api.py", line 98, in wrapper
results = fn(*arg, **kwargs)
File "/root/anaconda3/envs/mindspore/lib/python3.8/site-packages/mindspore/ops/composite/base.py", line 517, in after_grad
pynative_executor.grad(grad, fn, weights, grad_position, *args, **kwargs)
File "/root/anaconda3/envs/mindspore/lib/python3.8/site-packages/mindspore/common/api.py", line 819, in grad
self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values()))
RuntimeError: Illegal primitive: Primitive Col2Im's bprop not defined.

问题已经回复了issue,希望后续可以得到一个解决。

希望添加flod算子 · Issue #I663W0 · MindSpore/mindspore - Gitee.com

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值