import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
def image_show(images):
    """Display a channel-first image tensor (e.g. the output of make_grid).

    Args:
        images: A ``torch.Tensor`` shaped (C, H, W), or (H, W) for a single
            grayscale image. It may live on any device and may require grad;
            it is detached and copied to the CPU before conversion.

    Raises:
        ValueError: If ``images`` is not 2-D or 3-D.
    """
    # Tensor.numpy() only works on CPU tensors that do not require grad, so
    # detach and move first; both are no-ops for a plain CPU tensor.
    images = images.detach().cpu().numpy()
    if images.ndim not in (2, 3):
        raise ValueError(
            f'expected a (C, H, W) or (H, W) tensor, got shape {tuple(images.shape)}')
    if images.ndim == 3:
        # PyTorch images are channel-first (C, H, W); matplotlib expects
        # channel-last (H, W, C).
        images = images.transpose((1, 2, 0))
        if images.shape[-1] == 1:
            # imshow rejects (H, W, 1); drop the singleton channel so the image
            # is drawn as 2-D grayscale.
            images = images[..., 0]
    print(images.shape)
    # A 2-D array would otherwise be drawn with matplotlib's default colormap.
    plt.imshow(images, cmap='gray' if images.ndim == 2 else None)
    plt.show()
def main(root='./datasets', batch_size=32):
    """Load one batch of MNIST test images, tile it with make_grid, and show it.

    Args:
        root: Directory where the MNIST files are stored (or downloaded to).
        batch_size: Number of images taken from the loader and tiled.
    """
    # train=False selects the MNIST test split. download=True only fetches the
    # files when they are not already present under ``root``, so a fresh
    # checkout no longer fails with "Dataset not found".
    test_dataset = datasets.MNIST(root=root, train=False, download=True,
                                  transform=transforms.ToTensor())
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Only the first batch is needed for the visualization.
    inputs, targets = next(iter(test_loader))
    print(inputs.shape)
    print(targets.shape)

    # make_grid stitches the batch of images into a single grid image.
    images = torchvision.utils.make_grid(inputs)
    print(f'images.shape:{images.shape}')
    image_show(images)
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()
【深度学习】torchvision.utils.make_grid() 拼接图片
最新推荐文章于 2024-07-17 16:39:53 发布