成功安装mindspore后(已经用python -c "import mindspore;mindspore.run_check()" 验证成功安装),但是在使用Parameter出现报错
代码为官方文档1.8中的代码示例
class Loss(MyLoss):
"""Loss"""
def __init__(self, class_num):
super(Loss, self).__init__()
def construct(self, logit, label):
input_x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32), name='x')
input_x = mindspore.Parameter(Tensor(np.zeros((4, 4, 4)), mindspore.int32))
indices = Tensor(np.array([[0], [2]]), mindspore.int32)
updates = Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]]), mindspore.int32)
# output = ops.scatter_nd_add(input_x, indices, updates, False)
# print(output)
return 0
报错信息