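Earlier, when I ran the model in graph mode, I got a message saying the model contains dynamic-shape operators and that I needed to call set_inputs. So I added the following: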
network.set_inputs(ms.Tensor(shape=[opt.batch_size,2,1,3,opt.sequence_length,17],dtype=ms.float32,init=One()))
net_loss.set_inputs(ms.Tensor(shape=[opt.batch_size,128],dtype=ms.float32,init=One()),ms.Tensor(shape=[opt.batch_size,4],dtype=ms.int32,init=One()))
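With these settings, the code runs successfully in graph mode. Now, however, the model has a problem, and to debug it I need to set dataset_sink_mode=False in model.train. With that setting, training fails with the following error: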
Traceback (most recent call last):
File "train.py", line 259, in <module>
model.train(opt.epochs, train_ds, callbacks=[ckpoint,loss_monitor],dataset_sink_mode=False)
File "/home/ma-user/.local/lib/python3.7/site-packages/mindspore/train/model.py", line 911, in train
sink_size=sink_size)
File "/home/ma-user/.local/lib/python3.7/site-packages/mindspore/train/model.py", line 91, in wrapper
func(self, *args, **kwargs)
File "/home/ma-user/.local/lib/python3.7/site-packages/mindspore/train/model.py", line 547, in _train
self._train_process(epoch, train_dataset, list_callback, cb_params)
File "/home/ma-user/.local/lib/python3.7/site-packages/mindspore/train/model.py", line 799, in _train_process
outputs = self._train_network(*next_element)
File "/home/ma-user/.local/lib/python3.7/site-packages/mindspore/nn/cell.py", line 586, in __call__
out = self.compile_and_run(*args)
File "/home/ma-user/.local/lib/python3.7/site-packages/mindspore/nn/cell.py", line 963, in compile_and_run
self.compile(*inputs)
File "/home/ma-user/.local/lib/python3.7/site-packages/mindspore/nn/cell.py", line 939, in compile
self._check_compile_dynamic_shape(*inputs)
File "/home/ma-user/.local/lib/python3.7/site-packages/mindspore/nn/cell.py", line 2155, in _check_compile_dynamic_shape
f"For 'set_inputs', the Length of Tensor should be {len_inputs}, but got {len_dynamic_shape_inputs}."
ValueError: For 'set_inputs', the Length of Tensor should be 2, but got 1.
What could be causing this error? If I leave dataset_sink_mode at its default instead of setting it to False, the code runs fine.
Versions: MindSpore 1.7.0, Ascend 910.
****************************************************Answer*****************************************************
Switch to PyNative mode: context.set_context(mode=context.PYNATIVE_MODE)
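Why this helps, as far as the traceback shows: the error is raised by _check_compile_dynamic_shape inside Cell.compile, which is a graph-mode compilation step. With dataset_sink_mode=False, Model calls the training cell directly with each batch (self._train_network(*next_element)), and according to the f-string in the traceback that call passed 2 tensors while the registered set_inputs placeholders numbered only 1, which matches the single tensor given to network.set_inputs. This is a reading of the error message, not a confirmed diagnosis. Skipping graph compilation, as PyNative mode does, presumably sidesteps the check. The sketch below is a hypothetical, simplified pure-Python stand-in for that length comparison. The names dynamic_input_error and check_dynamic_inputs are made up for illustration and are not MindSpore APIs.

```python
from typing import Optional, Sequence


def dynamic_input_error(registered: Sequence, called: Sequence) -> Optional[str]:
    """Hypothetical stand-in: return the error text if the counts differ, else None.

    `registered` plays the role of the placeholders passed to set_inputs;
    `called` plays the role of the tensors the cell actually receives.
    """
    len_inputs = len(called)
    len_dynamic_shape_inputs = len(registered)
    if len_inputs != len_dynamic_shape_inputs:
        # Same wording as the ValueError in the traceback above.
        return (
            f"For 'set_inputs', the Length of Tensor should be {len_inputs}, "
            f"but got {len_dynamic_shape_inputs}."
        )
    return None


def check_dynamic_inputs(registered: Sequence, called: Sequence) -> None:
    """Hypothetical stand-in that raises instead of returning the message."""
    message = dynamic_input_error(registered, called)
    if message is not None:
        raise ValueError(message)


if __name__ == "__main__":
    # One placeholder registered (like network.set_inputs), two tensors per step.
    print(dynamic_input_error(["features"], ["data", "label"]))
    # Counts match, so no error.
    check_dynamic_inputs(["data", "label"], ["data", "label"])
```

In this simplified model, the mismatch goes away either when the number of placeholders matches the number of tensors the called cell receives, or when the check is never reached, which is the route the answer takes by switching to PyNative mode.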