概要
mnist 数据集链接:http://yann.lecun.com/exdb/mnist/
fashion_mnist:https://github.com/zalandoresearch/fashion-mnist
mnist 已经被用烂了,也太简单了。所以现在准备采用fashion_mnist。
两者的读取方式完全一致。这里以fashion mnist作为例子。
FashionMNIST 是一个替代 MNIST 手写数字集 的图像数据集。 它是由 Zalando(一家德国的时尚科技公司)旗下的研究部门提供。其涵盖了来自 10 种类别的共 7 万个不同商品的正面图片。
FashionMNIST 的大小、格式和训练集/测试集划分与原始的 MNIST 完全一致。60000/10000 的训练测试数据划分,28x28 的灰度图片。你可以直接用它来测试你的机器学习和深度学习算法性能,且不需要改动任何的代码。
说白了就是手写数字没有衣服鞋子之类的更复杂。
数据格式和mnist完全一致:
标注编号描述
0:T-shirt/top(T恤)
1:Trouser(裤子)
2:Pullover(套衫)
3:Dress(裙子)
4:Coat(外套)
5:Sandal(凉鞋)
6:Shirt(汗衫)
7:Sneaker(运动鞋)
8:Bag(包)
9:Ankle boot(踝靴)
代码
# -*- coding: utf-8 -*-
from sklearn import neighbors
from read_data import DataUtils
import datetime
import numpy as np
import struct
import matplotlib.pyplot as plt
def read_image(file_name):
'''
:param file_name: 文件路径
:return: 训练或者测试数据
如下是训练的图片的二进制格式
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
'''
file_handle=open(file_name,"rb") #以二进制打开文档
file_content=file_handle.read() #读取到缓冲区中
head = struct.unpack_from('>IIII', file_content, 0) # 取前4个整数,返回一个元组
offset = struct.calcsize('>IIII')
imgNum = head[1] #图片数
width = head[2] #宽度
height = head[3] #高度
bits = imgNum * width * height # data一共有60000*28*28个像素值
bitsString = '>' + str(bits) + 'B' # fmt格式:'>47040000B'
imgs = struct.unpack_from(bitsString, file_content, offset) # 取data数据,返回一个元组
imgs_array=np.array(imgs).reshape((imgNum,width*height)) #最后将读取的数据reshape成 【图片数,图片像素】二维数组
return imgs_array
def out_image(img):
'''
:param img: 图片像素组
:return:
'''
plt.figure()
plt.imshow(img)
plt.show()
def read_label(file_name):
'''
:param file_name:
:return:
标签的格式如下:
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The labels values are 0 to 9.
'''
file_handle = open(file_name, "rb") # 以二进制打开文档
file_content = file_handle.read() # 读取到缓冲区中
head = struct.unpack_from('>II', file_content, 0) # 取前2个整数,返回一个元组
offset = struct.calcsize('>II')
labelNum = head[1] # label数
bitsString = '>' + str(labelNum) + 'B' # fmt格式:'>47040000B'
label = struct.unpack_from(bitsString, file_content, offset) # 取data数据,返回一个元组
return np.array(label)
def get_data():
# 文件获取
train_image = "./mnist/train-images-idx3-ubyte"
test_image = "./mnist/t10k-images-idx3-ubyte"
train_label = "./mnist/train-labels-idx1-ubyte"
test_label = "./mnist/t10k-labels-idx1-ubyte"
# 读取数据
train_x = read_image(train_image)
test_x = read_image(test_image)
train_y = read_label(train_label)
test_y = read_label(test_label)
print(train_y[0:10])
print(test_y[0:10])
out_image(np.array(test_x[0]).reshape(28, 28))
return train_x,train_y,test_x,test_y
if __name__ == "__main__":
get_data()
结果展示
C:\ProgramData\Anaconda3\python.exe E:/hw/hw0.py
[9 0 0 3 0 2 7 2 5 5]
[9 2 1 1 6 1 4 6 5 7]
Process finished with exit code 0
从结果上看,9——裸靴,图片看正确。解析ok
后面就是采用各种机器学习算法进行分类。