如:HWC -> CHW
import cv2
img = cv2.imread('1.jpg')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
print(img.shape)
img = img.transpose(2, 0, 1)
print(img.shape)
# 输出为:
# (768, 1024, 3)
# (3, 768, 1024)
如果使用的是torch类型的:则可以使用permute函数.
img = torch.from_numpy(img).float().permute(2, 0, 1).unsqueeze(0)/255