统一存储
将MNIST中的所有图片不按标签区分,直接以序号命名,分别存储在data_train、data_test文件夹中;每张图片对应的标签按行写入train.txt、test.txt(第i行即第i张图片的标签)。
下面直接给出代码:
import os
import torch
import torchvision
import torchvision.transforms as transforms
from skimage import io
import torchvision.datasets.mnist as mnist
# Device configuration. Not used by the conversion code below; kept so that
# code built on top of this setup can move tensors to the GPU if available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# MNIST dataset. These objects are only created so torchvision fetches and
# extracts the raw IDX files under data/MNIST/raw/. Passing download=True to
# both splits makes each call self-sufficient instead of relying on the train
# call having fetched the test files; torchvision skips the download when the
# files are already present.
train_dataset = torchvision.datasets.MNIST(root='data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='data/',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

# Parse the raw IDX files directly into (images, labels) tensor pairs; the
# conversion code iterates them element-wise and calls .numpy() / .item().
root = './data/MNIST/raw/'
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte')),
)
test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte')),
)
def _save_split_flat(images, labels, data_path, label_file, tag, ext='.jpg'):
    """Write one split as <data_path><i><ext> images plus a one-label-per-line file.

    Line ``i`` of ``label_file`` holds the label of image ``<i><ext>``.
    ``tag`` ('train' or 'test') prefixes the progress messages.
    """
    # The context manager closes the label file even if an image write fails.
    with open(label_file, 'w') as f:
        os.makedirs(data_path, exist_ok=True)
        for i, (img, label) in enumerate(zip(images, labels)):
            img_path = data_path + str(i) + ext
            print(tag + '_img_path:', img_path, tag + '_img_num:', i)
            io.imsave(img_path, img.numpy())
            f.write(str(label.item()) + '\n')


def convert_to_img(train=True, ext='.jpg'):
    """Export one MNIST split into a single flat folder plus a label list.

    Images go to ./data/data_train/ (train) or ./data/data_test/ (test) and
    are named ``<index><ext>``. Line ``index`` of ./data/train.txt or
    ./data/test.txt holds the matching label.

    Args:
        train: True exports ``train_set``; False exports ``test_set``.
        ext: image file extension. Defaults to '.jpg' for compatibility.
            JPEG is lossy, so use '.png' to keep the exact pixel values.
    """
    if train:
        _save_split_flat(train_set[0], train_set[1], './data/data_train/',
                         './data/train.txt', 'train', ext)
    else:
        _save_split_flat(test_set[0], test_set[1], './data/data_test/',
                         './data/test.txt', 'test', ext)
# Export the training split first, then the test split.
convert_to_img(train=True)
convert_to_img(train=False)
按标签分别存储
按标签存储,即把图片按其标注的标签分别存放到名为0~9的文件夹中;这些文件夹由代码自动创建,无需手动建立。
import os
import torch
import torchvision
import torchvision.transforms as transforms
from skimage import io
import torchvision.datasets.mnist as mnist
import numpy
# Device configuration. Not used by the conversion code below; kept so that
# code built on top of this setup can move tensors to the GPU if available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# MNIST dataset. These objects are only created so torchvision fetches and
# extracts the raw IDX files under data/MNIST/raw/. Passing download=True to
# both splits makes each call self-sufficient instead of relying on the train
# call having fetched the test files; torchvision skips the download when the
# files are already present.
train_dataset = torchvision.datasets.MNIST(root='data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='data/',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

# Parse the raw IDX files directly into (images, labels) tensor pairs; the
# conversion code iterates them element-wise and calls .numpy() / .item().
root = './data/MNIST/raw/'
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte')),
)
test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte')),
)
def _save_split_by_label(images, labels, data_path, tag):
    """Write each image as <data_path><label>/<i>.png, creating label folders on demand.

    ``i`` is the image's index within the whole split, so file names stay
    unique across label folders. ``tag`` ('train' or 'test') prefixes the
    progress messages.
    """
    # Folders already ensured during this run; avoids one filesystem
    # check per image.
    created = set()
    for i, (img, label) in enumerate(zip(images, labels)):
        img_path = data_path + str(label.item()) + '/'
        print(tag + '_img_path:', img_path, 'img_num:', i)
        img_name = img_path + str(i) + '.png'
        if img_path not in created:
            os.makedirs(img_path, exist_ok=True)
            created.add(img_path)
        io.imsave(img_name, img.numpy())


def convert_to_img(train=True):
    """Export one MNIST split into per-label sub-folders.

    Images go to ./data/data_train01/<label>/ (train) or
    ./data/data_test01/<label>/ (test) as ``<index>.png``. The label folders
    are created automatically.

    Args:
        train: True exports ``train_set``; False exports ``test_set``.
    """
    if train:
        _save_split_by_label(train_set[0], train_set[1],
                             './data/data_train01/', 'train')
    else:
        _save_split_by_label(test_set[0], test_set[1],
                             './data/data_test01/', 'test')
# Export the training split first, then the test split.
convert_to_img(train=True)
convert_to_img(train=False)
好的,代码部分结束,完成!