In PyTorch, an image tensor fed into a network has shape [B, C, H, W] (batch, channels, height, width).
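The sketch below shows how images loaded with OpenCV/NumPy, which are uint8 arrays with layout [H, W, C], can be turned into such a [B, C, H, W] float tensor. It assumes all images have the same size; the helper name hwc_list_to_bchw and the 480×640 dummy images are hypothetical, chosen only for illustration.

```python
import numpy as np
import torch


def hwc_list_to_bchw(images):
    """Stack same-sized [H, W, C] uint8 arrays into one float32 [B, C, H, W] tensor."""
    # permute(2, 0, 1) moves the channel axis to the front: [H, W, C] -> [C, H, W]
    chw = [torch.from_numpy(im).permute(2, 0, 1) for im in images]
    # stack along a new leading axis -> [B, C, H, W], then convert to float32
    return torch.stack(chw, dim=0).float()


# Two dummy 480x640 3-channel images standing in for cv2.imread results
images = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(2)]
batch = hwc_list_to_bchw(images)
print(batch.shape)  # torch.Size([2, 3, 480, 640])
```

Tensor.permute plays the same role as the NumPy transpose([2,0,1]) in the example that follows; torch.stack fails if the images differ in size, so they would need to be resized first.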
Sometimes we need to resize an image tensor inside the network. This can be done with torchvision.transforms.Resize([H, W]), for example:
import cv2
import numpy as np
import torch
from torchvision.transforms import Resize
im1 = cv2.imread("./datasets/frame_0001.png").transpose([2,0,1]) # cv2 loads [H,W,C] in BGR order; after transpose shape=[C,H,W]
im1_torch = torch.from_numpy(im1.astype(np.float32)).unsqueeze(0) # shape=[B,C,H,W]
# im1_torch can be treated as the tensor fed into a PyTorch network.
torch_resize = Resize([256,256]) # create a Resize object with target size [H, W] = [256, 256]
im1_resize = torch_resize(im1_torch)
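# torchvision.transforms.Resize([H,W]) resizes the last two dimensions to [H,W].
# So the dimension order matters: the tensor must be channels-first ([..., C, H, W]), not [H, W, C].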
im1_resize_np = im1_resize.detach().cpu().numpy()[0].transpose(1, 2, 0) # drop batch dim, [C,H,W] -> [H,W,C] for OpenCV
print(im1.shape)
print(im1_resize.shape)
print(im1_resize_np.shape)
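# JPEG stores 8-bit data, so clip the float32 result to [0, 255] and convert to uint8 explicitly
cv2.imwrite("./datasets/frame_0001_resize.jpg", np.clip(im1_resize_np, 0, 255).astype(np.uint8))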
Reference: "Simple usage of PyTorch transforms.Resize()" (Pytorch transforms.Resize()的简单用法), xiongxyowo's blog, CSDN.