注意:
- PIL 图片的 size 格式:(w, h),即(宽, 高)
- PyTorch 图像 tensor 的常用格式:(c, h, w),即(通道, 高, 宽)
1. 先将图片转换成ndarray,然后再转换成tensor。
from PIL import Image
import numpy as np
import torch
# Path of the sample image converted below.
image_dir = '0001.png'
img = Image.open(image_dir)
img.show()  # display the image in an external viewer
# Bare expressions such as the ones below only echo a value when run
# interactively (notebook / REPL); they have no effect in a plain script.
img.size  # PIL reports (w, h)
type(img)
# np.array() on a PIL image gives (h, w, c) for multi-channel images but only
# (h, w) for single-channel ones, which would make the transpose below fail.
# np.atleast_3d appends a trailing channel axis in that case and leaves
# arrays that are already 3-D untouched.
img = np.atleast_3d(np.array(img))  # (w, h) ---> (h, w, c)
img.shape
img = img.transpose(2, 0, 1)  # (h, w, c) ---> (c, h, w)
img.shape
# torch.tensor copies the data and keeps the array's dtype; no value
# rescaling happens on this path.
img = torch.tensor(img)
img.shape
2. 使用 torchvision 的 transforms 将图片转换成 tensor(同一套 transforms 接口也可用于图像增广)。
from PIL import Image
import torch
from torchvision import transforms
# Sample image that is turned into a tensor further down.
image_dir = '0001.png'
img = Image.open(fp=image_dir)
img.size  # PIL reports (w, h)
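transforms.ToTensor():
- Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
- The scaling to [0.0, 1.0] is applied only to uint8 ndarrays and to PIL Images in one of the modes L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1; other inputs are converted without scaling.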
# Single-step pipeline: ToTensor turns the PIL image into a (c, h, w) tensor.
img_transform = transforms.Compose(
    transforms=[
        transforms.ToTensor(),
    ]
)
img = img_transform(img)
img.shape  # (c, h, w)