MNIST数据集转28*28大小的图片
准备
构架一个目录结构即可
在MNIST_IMG文件夹中新建TEST和TRAIN两个文件夹,然后在每个文件夹里面再建名字为0到9的10个子文件夹即可
TRAIN文件夹的结构与TEST类似,就不再粘贴了
代码
测试集
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import cv2
# Number of images per batch; the export routine in this script reshapes each
# batch to a single 28x28 array.
batch_size = 1

# The only preprocessing step is the conversion of each sample to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Training split, stored (and downloaded if necessary) under ./MNIST_IMG/mnist/.
train_dataset = datasets.MNIST(
    './MNIST_IMG/mnist/',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Test split from the same root directory, iterated in dataset order.
test_dataset = datasets.MNIST(
    './MNIST_IMG/mnist/',
    train=False,
    transform=transform,
    download=True,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
def save_train_data():
    """Export every image yielded by ``test_loader`` as a 28x28 JPEG file.

    Despite its name, this variant iterates the module-level ``test_loader``.
    Each image is written to ``./MNIST_IMG/TEST/<digit>/<k>.jpg``, where
    ``<digit>`` is the sample's label and ``<k>`` counts, from 0, how many
    images with that label have been written so far.

    The ten per-digit directories are created when missing. After each batch
    a progress line containing the batch index and the batch labels is printed.

    Raises:
        OSError: if OpenCV reports that an image could not be written.
    """
    import os  # function-scope import keeps this block self-contained

    out_root = './MNIST_IMG/TEST/'
    class_dirs = [out_root + str(digit) + '/' for digit in range(10)]
    for class_dir in class_dirs:
        # cv2.imwrite does not create directories and only returns False when
        # the target folder is missing, so make sure the folders exist.
        os.makedirs(class_dir, exist_ok=True)

    # counts[d] is the next file index to use for digit d.
    counts = [0] * 10
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        targets = targets.numpy()
        # Collapse the channel axis so any batch size yields (B, 28, 28).
        images = inputs.numpy().reshape(-1, 28, 28)
        # Assumes ToTensor-style intensities in [0, 1]: scale by 255 (not 256)
        # so the brightest pixel maps to 255, then store as 8-bit for JPEG.
        pixels = np.clip(np.rint(images * 255), 0, 255).astype(np.uint8)
        for image, label in zip(pixels, targets):
            label = int(label)
            path = class_dirs[label] + str(counts[label]) + '.jpg'
            if not cv2.imwrite(path, image):
                raise OSError('failed to write ' + path)
            counts[label] += 1
        print(str(batch_idx) + '.jpg' + '-' + str(targets))
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    save_train_data()
训练集
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import cv2
# Number of images per batch; the export routine in this script reshapes each
# batch to a single 28x28 array.
batch_size = 1

# The only preprocessing step is the conversion of each sample to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Training split, stored (and downloaded if necessary) under ./MNIST_IMG/mnist/.
# The loader shuffles, so the order in which images are visited varies by run.
train_dataset = datasets.MNIST(
    './MNIST_IMG/mnist/',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Test split from the same root directory, iterated in dataset order.
test_dataset = datasets.MNIST(
    './MNIST_IMG/mnist/',
    train=False,
    transform=transform,
    download=True,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
def save_train_data():
    """Export every image yielded by ``train_loader`` as a 28x28 JPEG file.

    Each image is written to ``./MNIST_IMG/TRAIN/<digit>/<k>.jpg``, where
    ``<digit>`` is the sample's label and ``<k>`` counts, from 0, how many
    images with that label have been written so far.

    The ten per-digit directories are created when missing. After each batch
    a progress line containing the batch index and the batch labels is printed.

    Raises:
        OSError: if OpenCV reports that an image could not be written.
    """
    import os  # function-scope import keeps this block self-contained

    out_root = './MNIST_IMG/TRAIN/'
    class_dirs = [out_root + str(digit) + '/' for digit in range(10)]
    for class_dir in class_dirs:
        # cv2.imwrite does not create directories and only returns False when
        # the target folder is missing, so make sure the folders exist.
        os.makedirs(class_dir, exist_ok=True)

    # counts[d] is the next file index to use for digit d.
    counts = [0] * 10
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        targets = targets.numpy()
        # Collapse the channel axis so any batch size yields (B, 28, 28).
        images = inputs.numpy().reshape(-1, 28, 28)
        # Assumes ToTensor-style intensities in [0, 1]: scale by 255 (not 256)
        # so the brightest pixel maps to 255, then store as 8-bit for JPEG.
        pixels = np.clip(np.rint(images * 255), 0, 255).astype(np.uint8)
        for image, label in zip(pixels, targets):
            label = int(label)
            path = class_dirs[label] + str(counts[label]) + '.jpg'
            if not cv2.imwrite(path, image):
                raise OSError('failed to write ' + path)
            counts[label] += 1
        print(str(batch_idx) + '.jpg' + '-' + str(targets))
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    save_train_data()