Project repository: https://github.com/jmu201521121021/yolo.pytorch
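I. Processing the MNIST dataset (converting it to 3-channel images; works on both Linux and Windows)
1. Download the MNIST dataset
Link: http://yann.lecun.com/exdb/mnist/
Download the four files linked on that page (shown in the figure below), create a folder named mnist, and put the four decompressed files into it.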
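The archives on the MNIST page are named like train-images-idx3-ubyte.gz, and gunzip leaves a file called train-images-idx3-ubyte, while the conversion script in step 2 opens dotted names such as train-images.idx3-ubyte. A minimal sketch of the decompression step, assuming the four .gz files are already in src_dir; decompress_mnist and MNIST_FILES are hypothetical names, and the renaming to dotted names is only there to match the script.

```python
import gzip
import os
import shutil

# The four archives listed on the MNIST page
MNIST_FILES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def decompress_mnist(src_dir, dst_dir="./mnist"):
    """Hypothetical helper: gunzip the four archives into dst_dir and rename
    them to the dotted names opened by the conversion script."""
    os.makedirs(dst_dir, exist_ok=True)
    out_paths = []
    for name in MNIST_FILES:
        # 'train-images-idx3-ubyte.gz' -> 'train-images.idx3-ubyte'
        out_name = name[:-len(".gz")].replace("-idx", ".idx")
        out_path = os.path.join(dst_dir, out_name)
        # Stream-decompress so the whole file never has to sit in memory
        with gzip.open(os.path.join(src_dir, name), "rb") as src, \
                open(out_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
        out_paths.append(out_path)
    return out_paths
```

If you keep gunzip's default names instead, change the file names in the conversion script rather than renaming the files.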
2. Convert the MNIST dataset to images
The code is as follows:
import os
import cv2
import torchvision.datasets.mnist as mnist

# Location of the downloaded MNIST files; adjust the names below if your
# decompressed files are called e.g. 'train-images-idx3-ubyte'
root = "./mnist"

train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images.idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels.idx1-ubyte'))
)
test_set = (
    mnist.read_image_file(os.path.join(root, 't10k-images.idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 't10k-labels.idx1-ubyte'))
)


def makeFiles(data_path):
    """Create ten sub-folders, 0 to 9, one per digit class."""
    for i in range(10):
        os.makedirs(os.path.join(data_path, str(i)), exist_ok=True)


def convert_to_img(style="train"):
    # Pick the training set for "train", otherwise the test set
    dataset = train_set if style == "train" else test_set
    # Create the train or val (test) folder under the mnist folder
    data_path = os.path.join(root, style)
    os.makedirs(data_path, exist_ok=True)
    makeFiles(data_path)
    # Write train.txt or val.txt inside the train or val folder
    with open(os.path.join(data_path, style + '.txt'), 'w') as f:
        # Save every sample as a 3-channel image in the folder of its digit
        for i, (img, label) in enumerate(zip(dataset[0], dataset[1])):
            label = int(label)  # 0-d tensor -> Python int (instead of str(label)[7])
            rel_path = '{}/{}.jpg'.format(label, i)
            new_img = cv2.cvtColor(img.numpy(), cv2.COLOR_GRAY2BGR)
            cv2.imwrite(os.path.join(data_path, rel_path), new_img)
            # One line per image: "<relative path> <label>"
            f.write('{} {}\n'.format(rel_path, label))


convert_to_img("train")
convert_to_img("val")  # uses the test set
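mnist.read_image_file and mnist.read_label_file parse MNIST's IDX binary format. For readers without torchvision, here is a minimal numpy-only sketch of that parsing (read_idx is a hypothetical name, not a torchvision function), assuming the standard IDX header: two zero bytes, a type byte (0x08 means unsigned bytes, which MNIST uses), a byte giving the number of dimensions, then one big-endian 32-bit size per dimension.

```python
import struct

import numpy as np


def read_idx(path):
    """Hypothetical helper: parse an uncompressed unsigned-byte IDX file
    into a numpy uint8 array shaped by the header's dimensions."""
    with open(path, "rb") as f:
        data = f.read()
    # Magic number: 2 zero bytes, data-type byte, number-of-dimensions byte
    zero, dtype_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("not an unsigned-byte IDX file: {}".format(path))
    # One big-endian uint32 size per dimension
    dims = struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim])
    offset = 4 + 4 * ndim
    # The pixel/label bytes follow the header directly
    return np.frombuffer(data, dtype=np.uint8, offset=offset).reshape(dims)
```

For example, read_idx('./mnist/train-images.idx3-ubyte') should yield a uint8 array of shape (60000, 28, 28); unlike read_image_file it returns a (read-only) numpy array instead of a tensor, so the .numpy() call in the script would be dropped.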
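3. Conversion result
The mnist folder now contains a train folder and a val folder.
The train folder contains ten sub-folders named 0 to 9, each holding the images of that digit, plus a train.txt file (the val folder has the same layout with val.txt).
Each line of val.txt (and likewise train.txt) gives the image name including its digit folder, then a space, then the class the image belongs to.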
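Because the class is stored twice on every line (as the folder name and as the label), the list files can be sanity-checked after conversion. A minimal sketch with a hypothetical helper check_split; it accepts both / and \ as separators, since a script that writes os.sep produces backslashes on Windows.

```python
from collections import Counter


def check_split(txt_path):
    """Hypothetical helper: count images per class in train.txt / val.txt
    and verify that each label matches the folder in its path."""
    counts = Counter()
    with open(txt_path) as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue  # ignore blank lines
            rel_path, label = line.split()
            # Folder part of "<label>/<index>.jpg", whichever separator was used
            folder = rel_path.replace("\\", "/").split("/")[0]
            if folder != label:
                raise ValueError("line {}: folder {} != label {}".format(
                    line_no, folder, label))
            counts[int(label)] += 1
    return counts
```

Calling check_split('./mnist/train/train.txt') returns a Counter mapping each digit to its number of images; the values should add up to the size of the split.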
II. Loading the MNIST dataset
import os
import numpy as np


def get_image_list(data_root, is_train):
    """
    Load image paths and labels from the generated txt file
    Args:
        data_root (str): root of the dataset, i.e. the mnist folder
        is_train (bool): True reads train/train.txt, False reads val/val.txt
    Returns:
        img_lists (list): path and label of every image in the split;
            every member of the list is a dict {'image_path': str, 'label': np.float32 array}
            e.g. {'image_path': './mnist/train/0/1.jpg', 'label': array(0., dtype=float32)}
    """
    assert data_root is not None, "The data_root is None, please use --data_root to set it"
    if is_train:
        txt_name = "train.txt"
        data_root = os.path.join(data_root, "train")
    else:
        txt_name = "val.txt"
        data_root = os.path.join(data_root, "val")
    txt_path = os.path.join(data_root, txt_name)
    assert os.path.exists(txt_path), "Can not find {}".format(txt_path)
    img_lists = []
    # Read the generated .txt file: one "<relative path> <label>" per line
    with open(txt_path, 'r') as f:
        for line in f:
            image_label = line.strip().split(" ")
            if len(image_label) < 2:
                continue  # skip blank lines
            image = image_label[0]
            label = np.array(image_label[1], dtype=np.float32)
            img_lists.append({
                # full path of the image
                'image_path': os.path.join(data_root, image),
                # class label of the image
                'label': label
            })
    return img_lists
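get_image_list only returns paths and labels; to feed them to training they are usually wrapped in a map-style dataset. A minimal sketch, assuming a hypothetical class name MnistListDataset (not the repository's own dataset class) and an OpenCV loader; it implements the __len__/__getitem__ pair that torch.utils.data.DataLoader relies on for map-style datasets, and transform can be any callable on the loaded image, such as the OpenCV-based preprocessing of section III.

```python
class MnistListDataset:
    """Hypothetical map-style dataset over the dicts from get_image_list."""

    def __init__(self, img_lists, loader=None, transform=None):
        self.img_lists = img_lists
        # Default to reading images with OpenCV; a custom loader can be injected
        self.loader = loader if loader is not None else self._cv2_loader
        self.transform = transform

    @staticmethod
    def _cv2_loader(path):
        import cv2  # imported lazily so the class is usable without OpenCV
        img = cv2.imread(path)  # BGR uint8 array, None if the file is unreadable
        if img is None:
            raise FileNotFoundError(path)
        return img

    def __len__(self):
        return len(self.img_lists)

    def __getitem__(self, index):
        item = self.img_lists[index]
        img = self.loader(item['image_path'])
        if self.transform is not None:
            img = self.transform(img)
        return img, item['label']
```

Assumed usage: dataset = MnistListDataset(get_image_list('./mnist', is_train=True)), then img, label = dataset[0]. Injecting the loader keeps file I/O separate from the indexing logic, which also makes the class easy to test without real images.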
III. transform (image preprocessing functions based on OpenCV-Python)