常用代码示例
设置随机种子
def set_seed(seed, deterministic=False):
    """Seed every random number generator used in a typical PyTorch workflow.

    Args:
        seed: Integer seed applied to Python's built-in ``random`` module,
            NumPy's global RNG, and PyTorch's CPU and CUDA generators.
        deterministic: If True, additionally force cuDNN to use deterministic
            kernels and disable its benchmark auto-tuner, trading speed for
            run-to-run reproducibility on GPU. Defaults to False, which keeps
            the original seeding-only behaviour.
    """
    import random  # local import keeps this snippet self-contained

    random.seed(seed)  # Python's built-in RNG was previously left unseeded.
    np.random.seed(seed)  # Numpy module.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
读取图像数据
读取图像数据的方式有多种，这里简单介绍两种：Pillow 和 OpenCV (cv2)。
Pillow (PIL)
# Import the required libraries
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np

# Read the image.
# Image.open() is lazy: it reads the header and keeps the file handle open until
# the pixel data is loaded. Using it as a context manager guarantees the file
# is closed once we are done with it.
with Image.open('./penguins.jpg') as raw_image:
    print(f"image :{raw_image}") # image :<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=700x500 at 0x7F9B11F59550>
    print("image:",type(raw_image)) # image: <class 'PIL.JpegImagePlugin.JpegImageFile'>
    # Normalize to 3-channel RGB: a JPEG may also be grayscale ("L") or "CMYK",
    # which would otherwise give 1 or 4 channels below and differ from the cv2
    # example. convert() also loads the pixels into memory, so `image` remains
    # valid after the file is closed.
    image = raw_image.convert("RGB")

# Convert the image to an np.ndarray laid out as (H, W, C).
# Note: despite its name, img_PIL is a NumPy array, not a PIL image.
img_PIL = np.array(image)
print("img_PIL:",img_PIL.shape) #img_PIL: (500, 700, 3)

# Define a transform to convert the image to tensor.
# ToTensor converts a PIL Image or numpy.ndarray (H x W x C) in the range
# [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
# if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
# or if the numpy.ndarray has dtype = np.uint8
transform = transforms.ToTensor()
# Convert the image to PyTorch tensor
tensor = transform(image)
# Print the converted tensor's shape; channels now come first, so for the
# 700x500 sample above this should be torch.Size([3, 500, 700]).
print(tensor.shape)
OpenCV (cv2)
# Import the required libraries
import torch
import cv2
import torchvision.transforms as transforms

# Read the image.
# cv2.imread does not raise on a missing or unreadable file; it returns None,
# which would only fail later inside cvtColor with an obscure error. Check it.
image = cv2.imread('penguins.jpg')
if image is None:
    raise FileNotFoundError("could not read image: penguins.jpg")
# OpenCV stores images in BGR channel order rather than the more common RGB
# order, so convert before handing the array to code that expects RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Define a transform to convert the image to tensor.
# Assuming imread's default 8-bit 3-channel output, the (H, W, C) uint8 array
# becomes a float tensor of shape (C, H, W) with values in [0.0, 1.0].
transform = transforms.ToTensor()
# Convert the image to PyTorch tensor
tensor = transform(image)
# Print the converted image tensor's shape
print(tensor.shape)
参考文献:
- https://pillow.readthedocs.io/en/stable/reference/Image.html
- https://zhuanlan.zhihu.com/p/275701235