源码中的一行代码:
img = torch.from_numpy(pic.transpose((2, 0, 1)))
再看我写的实验:
import cv2,numpy as np from torchvision import transforms trans=transforms.Compose( [ transforms.ToTensor(), ]) print(np.shape(cv2.imread("test.jpg"))) print(trans(cv2.imread("test.jpg")).size())
输出结果为:
(96, 96, 3)
torch.Size([3, 96, 96])
也就是说从opencv读到的图片是通道数在第三个维度,现在经过ToTensor操作cv2图片变成了torch image类型,也就是通道在第一个维度。接下来就不用自己显式的用tensor的permute方法来执行维度变化了。可以看我另一篇关于permute的用法:PyTorch中permute的用法