本次实践是基于李宏毅老师ML课程中“Hello world” of deep learning章节进行的实验探究。课程中李宏毅老师使用的是keras 2.0.1。可以参考中文文档自行安装。
参考链接:
Keras中文文档
李宏毅老师“Hello world” of deep learning PPT
一张图解释为什么要用Keras
1. Keras可以看做TensorFlow的接口,用Keras就相当于在用TensorFlow。
2. Keras集成了TensorFlow的许多复杂操作,使用起来更简洁。
3. 入门直接上手TensorFlow比较复杂。
数据准备
课程中获取数据的方法是从库中直接load_data
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
但是由于网络限制(不知道是不是因为wall,试了几次都没成功),就直接去官网下了数据,参见mnist数据集下载地址。
该数据下载后得到的是idx格式数据,具体处理方法参考了这篇博客使用Python解析MNIST数据集(IDX文件格式),测试可用的源码如下(规则在注释里写得很详细),该文件保存为load_data.py,在后文中会直接调用。
# encoding: utf-8
"""
对MNIST手写数字数据文件转换为bmp图片文件格式。
数据集下载地址为http://yann.lecun.com/exdb/mnist。
相关格式转换见官网以及代码注释。
========================
关于IDX文件格式的解析规则:
========================
THE IDX FILE FORMAT
the IDX file format is a simple format for vectors and multidimensional matrices of various numerical types.
The basic format is
magic number
size in dimension 0
size in dimension 1
size in dimension 2
.....
size in dimension N
data
The magic number is an integer (MSB first). The first 2 bytes are always 0.
The third byte codes the type of the data:
0x08: unsigned byte
0x09: signed byte
0x0B: short (2 bytes)
0x0C: int (4 bytes)
0x0D: float (4 bytes)
0x0E: double (8 bytes)
The 4-th byte codes the number of dimensions of the vector/matrix: 1 for vectors, 2 for matrices....
The sizes in each dimension are 4-byte integers (MSB first, high endian, like in most non-Intel processors).
The data is stored like in a C array, i.e. the index in the last dimension changes the fastest.
"""
import numpy as np
import struct
import matplotlib.pyplot as plt
# 训练集文件
train_images_idx3_ubyte_file = './data/train-images-idx3-ubyte'
# 训练集标签文件
train_labels_idx1_ubyte_file = './data/train-labels-idx1-ubyte'
# 测试集文件
test_images_idx3_ubyte_file = './data/t10k-images-idx3-ubyte'
# 测试集标签文件
test_labels_idx1_ubyte_file = './data/t10k-labels-idx1-ubyte'
def decode_idx3_ubyte(idx3_ubyte_file):
"""
解析idx3文件的通用函数
:param idx3_ubyte_file: idx3文件路径
:return: 数据集
"""
# 读取二进制数据
bin_data = open(idx3_ubyte_file, 'rb').read()
# 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
offset = 0
fmt_header = '>iiii'
magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)