Pytorch: torchvision.utils.make_grid函数的说明
网格化显示数据
# 环境准备
import numpy as np # numpy数组库
import matplotlib.pyplot as plt # 画图库
import torchvision.datasets as dataset # 公开数据集的下载和管理
import torchvision.transforms as transforms # 公开数据集的预处理库,格式转换
import torchvision
import torch.utils.data as data_utils # 对数据集进行分批加载的工具集
# 2-1 准备数据集
train_data = dataset.MNIST(root="data",
train=True,
transform=transforms.ToTensor(),
download=True)
# 2-1 准备数据集
test_data = dataset.MNIST(root="data",
train=False,
transform=transforms.ToTensor(),
download=True)
# 批量数据读取
train_loader = data_utils.DataLoader(dataset=train_data,
batch_size=64,
shuffle=True)
test_loader = data_utils.DataLoader(dataset=test_data,
batch_size=64,
shuffle=True)
def imshow(img):
# img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0))) # 将【c,h,w】-->【h,w,c】
plt.show()
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print("\n合并成一张三通道灰度图片")
images = torchvision.utils.make_grid(imgs, nrow=8, padding=0)
#保存图片
from torchvision.utils import save_image
save_image(images,'image.png')
#显示图片
print(images.shape)
imshow(images)