1 报错描述
1.1 系统环境
Hardware Environment(Ascend/GPU/CPU): Ascend Software Environment: -- MindSpore version (source or binary): 1.6.0 -- Python version (e.g., Python 3.7.5): 3.7.6 -- OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 4.15.0-74-generic -- GCC/Compiler version (if compiled from source):
1.2 基本信息
1.2.1 脚本
根据标杆torch算子NLLLoss的用例,编写输入为(N, C, d1, d2)的用例,脚本如下:
01 loss = nn.NLLLoss()
02 input = torch.randn(5, 4, 8, 8)
03 m = nn.LogSoftmax(dim=1)
04 target = torch.empty(5, 8, 8, dtype=torch.long).random_(0, 4)
05 loss = loss(m(input), target)
06 print('torch_loss',loss)
07
08 m = mn.LogSoftmax(axis=1)
09 loss = ops.NLLLoss()
10 input = Tensor(np.random.randn(5, 4,8,8), mindspore.float32)
11 labels = Tensor([1, 0,1, 1], mindspore.int32)
12 weight = Tensor(np.random.rand(5,8,8), mindspore.float32)
13 loss, weight = loss(m(input), labels, weight)
14 print('mindspore_loss:',loss)
复制
1.2.2 报错
这里报错信息如下:
Traceback (most recent call last):
File demo.py, line 13, in <module>
loss, weight = loss(m(input), labels, weight)
…
File "/lib/python3.7/site-packages/mindspore/_checkparam.py", line 238, in check_int
return check_number(arg_value, value, rel, int, arg_name, prim_name)
File " /lib/python3.7/site-packages/mindspore/_checkparam.py", line 168, in check_number
raise type_except(f'{prim_info} should be {arg_type.__name__} and must {rel_str}, '
ValueError: `x rank` in `NLLLoss` should be int and must in [1, 2], but got `4` with type `int
复制
原因分析
在MindSpore 1.6版本,利用对标算子的用例编写输入为(N, C, d1, d2)的用例。先看报错信息,在ValueError中,写到x rank in NLLLos should be int and must in [1, 2], but got 4 with type int,意思是传的NLLLoss的x_rank参数应该为int,而且应该在[1, 2]之间,但是你传进去的是int类型的4,由报错行数line13,检查传入数据可知我们传入了为4维的(5, 4, 8, 8),而torch传入该类型数据能支持,这是由于目前MindSpore暂不支持(N,C,d1,d2,...,dK) with K≥1类型,这点在旧版本的PyTorch与MindSpore API映射对比中描述为功能一致,此处文档有误(见下图),目前标杆算子和MindSpore框架支持的NLLLoss算子功能存在一定差异,MindSpore目前只支持 shape为(N,C)的数据(如下图所示)。
在新版本中的描述中已经进行修改。参考链接为:比较与torch.nn.NLLLoss的功能差异 — MindSpore master documentation 。
2 解决方法
基于上面已知的原因,该算子存在部分输入不支持的情况,目前需要用户自己封装,该操作会对用户带来一定的困扰,我们将在后续统一考虑这种需求。
3 总结
定位报错问题的步骤:
1、找到报错的用户代码行:*loss, weight = loss(m(input), labels, weight)*;
2、 根据日志报错信息中的关键字,缩小分析问题的范围:*loss, weight = loss(m(input), labels, weight)*;
3、需要重点关注变量定义、初始化的正确性。
4 参考文档
4.1 NLLLoss算子API介绍