报错提示:assert len(tensor.shape) == len(AssertionError: size of input tensor and input format are different. tensor shape: (196608,), input_format: CHW
理解and解决办法:
报错提示的意思就是说, input tensor的维度是(196608,),但是 input_format的格式需要是 CHW,两者不匹配。
我这里的话只要不进行torch.flatten()这步操作,其实是可以正常显示的。展平之后数据格式发生变化了,已经不能和所要求的格式匹配上,所以会出现报错。