pytorch dataset图片显示
- plt参数可以是pytorch的tensor,也可以是numpy的array,但是维度必须是x*x*3;
- PIL.Image只能转化为x*x*3的numpy数组,不能转化为torch的tensor;
- torchvision的transform出来的是3*x*x的tensor,如果要画图需要转化为x*x*3用plt显示。
import torch
import torchvision
import argparse
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
dataset = torchvision.datasets.ImageFolder(
root=args.dataset_dir + "/val",
)
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
image, label = dataset[0]
image = np.array(image)
ax[0].imshow(image)
ax[0].set_title(str(label))
ax[0].axis("off")
trans_image = TransformsSimCLR(args.image_size).train_transform(Image.fromarray(image))
ax[1].imshow(trans_image.permute(1, 2, 0))
ax[1].axis("off")
plt.show()