1.系统环境
硬件环境(Ascend/GPU/CPU): GPU
软件环境:
– MindSpore 版本: 1.7.0
执行模式: 动态图(PYNATIVE_MODE) – Python 版本: 3.7.6
– 操作系统平台: linux
2.报错信息
2.1 问题描述
自定义的loss函数中没有继承nn.Cell,并且没有实现construct函数。导致mindspore不支持。
2.2 报错信息
[mindspore \ccsrc\pipeline \jit\parse\parse . cc :584] ParseStatement] Unsupported statement 'Try'.
2.3 脚本代码
3.根因分析
从报错信息较难看出具体出错的原因。经过调试发现在FaceLoss函数init中不支持diy_loss函数这样的写法。也需要继承nn.cell同时编写construct函数。
4.补充知识
从mindspore文档也发现都需要继承nn.cell。
网络基本单元 Cell 当用户需要自定义网络时,需要继承Cell类,并重写__init__方法和construct方法。损失函数、优化器和模型层等本质上也属于网络结构,也需要继承Cell类才能实现功能,同样用户也可以根据业务需求自定义这部分内容。
5.解决方案
解决方案说明:重写diy_loss函数,继承nn.cell同时编写construct函数。
修改后代码:
class diy_loss(nn.Cell):
def __init__(self,target_emb):
super(diy_loss, self).__init__()
self.uniformreal = ops.UniformReal(seed=2)
self.sum = ops.ReduceSum(keep_dims=False)
self.norm = nn.Norm(keep_dims=True)
self.zeroslike = ops.ZerosLike()
self.concat_op1 = ops.Concat(1)
self.concat_op2 = ops.Concat(2)
self.pow = ops.Pow()
self.reduce_sum = ops.operations.ReduceSum()
self.target_emb = target_emb
self.abs = ops.Abs()
self.reduce_mean = ops.ReduceMean()
def construct(self, adversarial_emb,input_emb,mask_tensor): vert_diff = mask_tensor[:, 1:] - mask_tensor[:, :-1] hor_diff = mask_tensor[:, :, 1:] - mask_tensor[:, :, :-1] vert_diff_sq = self.pow(vert_diff, 2) hor_diff_sq = self.pow(hor_diff, 2) A = self.zeroslike(Tensor(self.uniformreal((3, 1, 112)))) B = self.zeroslike(Tensor(self.uniformreal((3, 112, 1)))) vert_pad = self.concat_op1((vert_diff_sq, A)) hor_pad = self.concat_op2((hor_diff_sq, B)) tv_sum = vert_pad + hor_pad tv = ops.functional.sqrt(tv_sum + 1e-5) tv_final_sum = self.sum(tv) tv_loss = (1e-4) * tv_final_sum print("tv_loss:",tv_loss) prod_sum = self.reduce_sum(adversarial_emb * self.target_emb, (1,)) square1 = self.reduce_sum(ops.functional.square(adversarial_emb), (1,)) square2 = self.reduce_sum(ops.functional.square(self.target_emb), (1,)) denom = ops.functional.sqrt(square1) * ops.functional.sqrt(square2) cos_loss = -(prod_sum / denom) print("dis_loss:", cos_loss) return tv_loss+cos_loss
修改后可以正常运行。