深度框架 MXNet/Gluon 初体验

本文介绍了MXNet深度学习框架及其Gluon接口,强调其灵活性和效率。通过多层感知机在MNIST数据集上的应用,详细阐述了数据加载、模型构建、模型可视化、训练过程以及如何利用GPU加速训练。提供了一个完整的MXNet/Gluon实战教程,帮助读者快速掌握该框架。
摘要由CSDN通过智能技术生成

MXNet: A flexible and efficient library for deep learning.

这是MXNet的官网介绍,“MXNet是灵活且高效的深度学习库”。

MXNet是主流的三大深度学习框架之一:

  • TensorFlow:Google支持,其简化版是Keras
  • PyTorch:Facebook支持,其工业版是Caffe2
  • MXNet:中立,Apache孵化器项目,也被AWS选为官方DL平台;

MXNet的优势是,其开发者之一李沐,是中国人,在MXNet的推广中具有语言优势(汉语),有利于国内开发者的学习。同时,推荐李沐录制的教学视频,非常不错。

MXNet的高层接口是Gluon,Gluon同时支持灵活的动态图和高效的静态图,既保留动态图的易用性,也具有静态图的高性能,这也是官网介绍的flexibleefficient的出处。同时,MXNet还具备大量学术界的前沿算法,方便移植至工业界。希望MXNet团队再接再励,在深度学习框架的竞赛中,位于前列。

MXNet

因此,掌握 MXNet/Gluon 很有必要。

本文以深度学习的多层感知机(Multilayer Perceptrons)为算法基础,数据集选用MNIST,介绍MXNet的工程细节。

本文的源码https://github.com/SpikeKing/gluon-tutorial


数据集

在虚拟环境(Virtual Env)中,直接使用pip安装MXNet即可:

pip install mxnet

如果下载速度较慢,推荐使用阿里云的pypi源:

-i http://mirrors.aliyun.com/pypi/simple --trusted-host mirrors.aliyun.com

MNIST就是著名的手写数字识别库,其中包含0至9等10个数字的手写体,图片大小为28*28的灰度图,目标是根据图片识别正确的数字。

MNIST库在MXNet中被封装为MNIST类,数据存储于.mxnet/datasets/mnist中。如果下载MNIST数据较慢,可以选择到MNIST官网下载,放入mnist文件夹中即可。在MNIST类中:

  • 参数train:是否为训练数据,其中true是训练数据,false是测试数据;
  • 参数transform:数据的转换函数,lambda表达式,转换数据和标签为指定的数据类型;

源码:

# 参数train
if self._train:
    data, label = self._train_data, self._train_label
else:
    data, label = self._test_data, self._test_label

# 参数transform
if self._transform is not None:
    return self._transform(self._data[idx], self._label[idx])
return self._data[idx], self._label[idx]

在MXNet中,数据加载类被封装成DataLoader类,迭代器模式,迭代输出与批次数相同的样本集。在DataLoader中,

  • 参数dataset:数据源,如MNIST;
  • 参数batch_size:训练中的批次数量,在迭代中输出指定数量的样本;
  • 参数shuffle:是否洗牌,即打乱数据,一般在训练时需要此操作。

迭代器的测试,每次输出样本个数(第1维)与指定的批次数量相同:

for data, label in train_data:
    print(data.shape)  # (64L, 28L, 28L, 1L)
    print(label.shape)  # (64L,)
    break

load_data()方法中,输出训练和测试数据,数据类型是0~1(灰度值除以255)的浮点数,标签类型也是浮点数。

具体实现:

def load_data(self):
    def transform(data, label):
        return data.astype(np.float32) / 255., label.astype(np.float32)
    train_data = DataLoader(MNIST(train=True, transform=transform),
                            self.batch_size, shuffle=True)
    test_data = DataLoader(MNIST(train=False, transform=transform),
                           self.batch_size, shuffle=False)
    return train_data, test_data

模型

网络模型使用MXNet中Gluon的样式:

  1. 创建Sequential()序列,Sequential是全部操作单元的容器;
  2. 添加全连接单元Dense,参数units是输出单元的个数,参数activation是激活函数;
  3. 初始化参数:
    • init是数据来源,Normal类即正态分布,sigma是正态分布的标准差;
  • 20
    点赞
  • 113
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CarolineSpike

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值