MXNet: A flexible and efficient library for deep learning.
这是MXNet的官网介绍,“MXNet是灵活且高效的深度学习库”。
MXNet是主流的三大深度学习框架之一:
- TensorFlow:Google支持,其简化版是Keras;
- PyTorch:Facebook支持,其工业版是Caffe2;
- MXNet:中立,Apache孵化器项目,也被AWS选为官方DL平台;
MXNet的优势是,其开发者之一李沐,是中国人,在MXNet的推广中具有语言优势(汉语),有利于国内开发者的学习。同时,推荐李沐录制的教学视频,非常不错。
MXNet的高层接口是Gluon,Gluon同时支持灵活的动态图和高效的静态图,既保留动态图的易用性,也具有静态图的高性能,这也是官网介绍的flexible和efficient的出处。同时,MXNet还具备大量学术界的前沿算法,方便移植至工业界。希望MXNet团队再接再励,在深度学习框架的竞赛中,位于前列。
因此,掌握 MXNet/Gluon 很有必要。
本文以深度学习的多层感知机(Multilayer Perceptrons)为算法基础,数据集选用MNIST,介绍MXNet的工程细节。
本文的源码:https://github.com/SpikeKing/gluon-tutorial
数据集
在虚拟环境(Virtual Env)中,直接使用pip安装MXNet即可:
pip install mxnet
如果下载速度较慢,推荐使用阿里云的pypi源:
-i http://mirrors.aliyun.com/pypi/simple --trusted-host mirrors.aliyun.com
MNIST就是著名的手写数字识别库,其中包含0至9等10个数字的手写体,图片大小为28*28的灰度图,目标是根据图片识别正确的数字。
MNIST库在MXNet中被封装为MNIST类,数据存储于.mxnet/datasets/mnist
中。如果下载MNIST数据较慢,可以选择到MNIST官网下载,放入mnist文件夹中即可。在MNIST类中:
- 参数
train
:是否为训练数据,其中true是训练数据,false是测试数据; - 参数
transform
:数据的转换函数,lambda表达式,转换数据和标签为指定的数据类型;
源码:
# 参数train
if self._train:
data, label = self._train_data, self._train_label
else:
data, label = self._test_data, self._test_label
# 参数transform
if self._transform is not None:
return self._transform(self._data[idx], self._label[idx])
return self._data[idx], self._label[idx]
在MXNet中,数据加载类被封装成DataLoader类,迭代器模式,迭代输出与批次数相同的样本集。在DataLoader中,
- 参数
dataset
:数据源,如MNIST; - 参数
batch_size
:训练中的批次数量,在迭代中输出指定数量的样本; - 参数
shuffle
:是否洗牌,即打乱数据,一般在训练时需要此操作。
迭代器的测试,每次输出样本个数(第1维)与指定的批次数量相同:
for data, label in train_data:
print(data.shape) # (64L, 28L, 28L, 1L)
print(label.shape) # (64L,)
break
在load_data()
方法中,输出训练和测试数据,数据类型是0~1(灰度值除以255)的浮点数,标签类型也是浮点数。
具体实现:
def load_data(self):
def transform(data, label):
return data.astype(np.float32) / 255., label.astype(np.float32)
train_data = DataLoader(MNIST(train=True, transform=transform),
self.batch_size, shuffle=True)
test_data = DataLoader(MNIST(train=False, transform=transform),
self.batch_size, shuffle=False)
return train_data, test_data
模型
网络模型使用MXNet中Gluon的样式:
- 创建
Sequential()
序列,Sequential是全部操作单元的容器; - 添加全连接单元Dense,参数units是输出单元的个数,参数activation是激活函数;
- 初始化参数:
- init是数据来源,Normal类即正态分布,sigma是正态分布的标准差;