出现错误AssertionError: size of input tensor and input format are different. tensor shape: (64, 3, 32, 32), input_format: CHW
代码示例:
writer.add_image("output",output,step)
在运行过程中出现错误:
从报错信息可以看出 input tensor的shape是(64,3,32,32),但是imput_format应该是CHW,我们在输入时把batch_size作为一个维度也输入进入了,所以不匹配。add_image只接收单一图像,所以解决这个问题只需要把add_image改成add_images