标题【torch加载数据集并显示】
注:以下代码在Jupyter Notebook中运行
torchvision内置的数据集包括(以下仅列出部分,新版本中还有更多):
MNIST
COCO(用于图像标注和目标检测)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10
1. 加载MNIST数据集
import torchvision
from torchvision import transforms
import torch
from torchvision.utils import save_image
# MNIST training split (train=True), stored under ./data; download=True lets
# torchvision fetch the files from the network.
# ToTensor() converts each image to a float tensor (the printed batch below
# shows torch.float32 values with a max of 1.).
dataset = torchvision.datasets.MNIST(
    'data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
# Shuffled mini-batches of 100 (image, label) pairs.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)
'''
MNIST函数原型:
MNIST(root, train=True, transform=None, target_transform=None, download=False)
root:下载到哪里,或从哪里读取
train:True为训练集,False为测试集
download:是否从网络下载
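save_image函数原型(旧版torchvision;较新版本中`range`参数已更名为`value_range`,签名变为save_image(tensor, fp, format=None, **kwargs),nrow、padding、normalize、value_range、scale_each、pad_value等参数透传给make_grid):
torchvision.utils.save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)
nrow:网格中每行放置的图像数量(即列数),不是行数;padding:子图间隔的像素宽度;pad_value:间隔像素的填充值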
参考:
https://blog.csdn.net/weixin_43723625/article/details/108159190?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522167939039516800186592443%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=167939039516800186592443&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_ecpm_v1~rank_v31_ecpm-1-108159190-null-null.142^v74^control_1,201^v4^add_ask,239^v2^insert_chatgpt&utm_term=torchvision.utils.save_image%28%29%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E&spm=1018.2226.3001.4187
类似函数:make_grid
参考:https://blog.csdn.net/u012343179/article/details/83007296?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522167939066516800182128063%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=167939066516800182128063&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_ecpm_v1~rank_v31_ecpm-2-83007296-null-null.142^v74^control_1,201^v4^add_ask,239^v2^insert_chatgpt&utm_term=torchvision.utils.make_grid%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E&spm=1018.2226.3001.4187
'''
2. 将一个批次的图像拼接为网格并保存为BMP格式
# Take only the first shuffled batch (100 images) for display.
x, y = next(iter(data_loader))
print(x.size(), x.dtype, x.max())
print(y)
# Tile the batch into a grid and write it to 'mnist100.bmp'.
# nrow=10 -> 10 images per row (a 10x10 grid for 100 images);
# pad_value=1 fills the gaps between tiles with 1 (white, since values are in [0, 1]).
save_image(x, 'mnist100.bmp', nrow=10, pad_value=1)
输出:
torch.Size([100, 1, 28, 28]) torch.float32 tensor(1.)
tensor([3, 4, 8, 2, 1, 2, 7, 9, 3, 0, 0, 1, 4, 1, 5, 8, 4, 7, 8, 5, 6, 6, 6, 0,
6, 0, 0, 8, 1, 1, 8, 4, 1, 1, 4, 8, 6, 1, 0, 2, 4, 2, 8, 9, 5, 2, 6, 8,
4, 1, 1, 7, 4, 8, 1, 3, 8, 1, 5, 9, 3, 3, 7, 2, 1, 2, 4, 6, 7, 4, 3, 0,
2, 9, 7, 4, 0, 7, 6, 9, 5, 0, 8, 2, 6, 7, 7, 2, 8, 2, 6, 9, 2, 0, 2, 1,
2, 2, 0, 9])
3. 显示保存的图片
import matplotlib.pyplot as plt
import cv2

img = cv2.imread('mnist100.bmp')
if img is None:
    # cv2.imread returns None (instead of raising) when the file is missing/unreadable.
    raise FileNotFoundError("could not read 'mnist100.bmp'; run the save_image cell first")
# OpenCV loads channels in BGR order, while matplotlib interprets arrays as RGB.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.xticks([])  # hide x-axis tick labels
plt.yticks([])  # hide y-axis tick labels
plt.show()
输出:
4. 扩展:加载彩色CIFAR10
先导入第1节中用到的库(torchvision、transforms、torch、save_image)
输出彩色拼图
# CIFAR10 training split (train=True), stored under ./cifar10data; download=True
# lets torchvision fetch the files from the network.
# ToTensor() converts each image to a float tensor (the printed batch below
# shows torch.float32 values with a max of 1.).
dataset = torchvision.datasets.CIFAR10(
    'cifar10data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
# Shuffled mini-batches of 100 (image, label) pairs.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)
# Take only the first shuffled batch (100 images) for display.
x, y = next(iter(data_loader))
print(x.size(), x.dtype, x.max())
print(y)
# Tile the batch into a grid and write it to 'cifar100.bmp' (lossless, and the
# file name the display cell reads back).
# nrow=10 -> 10 images per row (a 10x10 grid for 100 images);
# pad_value=1 fills the gaps between tiles with 1 (white, since values are in [0, 1]).
save_image(x, 'cifar100.bmp', nrow=10, pad_value=1)
输出:
torch.Size([100, 3, 32, 32]) torch.float32 tensor(1.)
tensor([6, 4, 5, 3, 1, 4, 5, 9, 3, 7, 3, 9, 6, 4, 7, 3, 4, 0, 0, 8, 3, 0, 5, 2,
3, 9, 1, 0, 7, 9, 7, 7, 1, 8, 1, 2, 8, 9, 1, 8, 9, 0, 0, 2, 2, 2, 5, 2,
1, 2, 1, 3, 9, 8, 0, 3, 0, 7, 7, 0, 8, 0, 1, 3, 8, 8, 3, 9, 3, 3, 8, 6,
5, 5, 1, 5, 6, 2, 3, 9, 5, 5, 5, 2, 6, 4, 8, 7, 3, 4, 1, 9, 9, 3, 0, 7,
6, 7, 5, 3])
显示彩色拼图
import matplotlib.pyplot as plt
import cv2

img = cv2.imread('cifar100.bmp')
if img is None:
    # cv2.imread returns None (instead of raising) when the file is missing/unreadable.
    raise FileNotFoundError("could not read 'cifar100.bmp'; run the save_image cell first")
# OpenCV loads channels in BGR order; without this conversion matplotlib would
# show the colour images with red and blue swapped.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(img)
plt.xticks([])  # hide x-axis tick labels
plt.yticks([])  # hide y-axis tick labels
plt.show()
输出: