MxNet系列——how_to——torch

博客新址: http://blog.xuezhisd.top
邮箱:xuezhisd@126.com


如何将MXNet用作Torch的前后端

本章节描述了如何将MXNet用作Torch的两个主要功能(前端和后端):

  • 使用MXNet.NDArray来调用Torch的张量数学函数。

  • 将Torch的神经网络模块(层)嵌入到MXNet的符号图中。

编译支持Torch的MXNet

  • 参照 官方教程 来安装Torch
    • 如果还没有安装Torch,将配置文件 make/config.mk (Linux) 或 make/osx.mk (Mac) 复制到MXNet根目录中,并命名为 config.mk。取消文件 config.mk 中的两行注释: TORCH_PATH = $(HOME)/torchMXNET_PLUGINS += plugin/torch/torch.mk
    • 此处默认Torch安装在当前用户的主目录下(TORCH_PATH = $(HOME)/torch)。如果Torch没有安装在此目录,将参数 TORCH_PATH 修改成torch的安装目录。
  • 运行命令 make clean && make 来构建可以使用Torch的MXNet。

与张量相关的数学函数

mxnet.th模块支持调用Torch的张量数学函数和mxnet.nd.NDArray一起使用。查看 完整代码

import mxnet as mx
x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print x.asnumpy()
y = mx.th.abs(x)
print y.asnumpy()

x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print x.asnumpy()
mx.th.abs(x, x) # 原地计算
print x.asnumpy()

使用命令 help(mx.th) 获取更多帮助。

现在我们已经支持网页 Torch’s documentation page.上的最常用的函数。如果你发现你需要的函数还没有支持,你可以通过参考已经登记的函数,轻易地将它登记在页面 mxnet_root/plugin/torch/torch_function.cc 上。

Torch 模块 (网络层)

MXNet通过mxnet.symbol.TorchModule模块来支持Torch的神经网络模块。比如,下面的代码定义了一个对MNIST数据库进行分类的3层DNN网络。 (完整代码):

data = mx.symbol.Variable('data')
fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1')
act1 = mx.symbol.TorchModule(data_0=fc1, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu1')
fc2 = mx.symbol.TorchModule(data_0=act1, lua_string='nn.Linear(128, 64)', num_data=1, num_params=2, num_outputs=1, name='fc2')
act2 = mx.symbol.TorchModule(data_0=fc2, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu2')
fc3 = mx.symbol.TorchModule(data_0=act2, lua_string='nn.Linear(64, 10)', num_data=1, num_params=2, num_outputs=1, name='fc3')
mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')

下面,分析一下上述代码。首先 data = mx.symbol.Variable('data') 定义一个符号变量作为输入的占位符。然后,fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1') 将符号变量data传递给Torch的NN模块。如果你想使用Torch的Criterion作为损失函数,只需将最后一行替换成:

logsoftmax = mx.symbol.TorchModule(data_0=fc3, lua_string='nn.LogSoftMax()', num_data=1, num_params=0, num_outputs=1, name='logsoftmax')
# Torch的标签从1开始
label = mx.symbol.Variable('softmax_label') + 1
mlp = mx.symbol.TorchCriterion(data=logsoftmax, label=label, lua_string='nn.ClassNLLCriterion()', name='softmax')

nn模块的输入数据的命名估规则是 data_i,其中 i = 0 … num_data-1lua_string 是一个用来创建模块对象的单行Lua语句;对于Torch的内建模块,形式如nn.module_name(arguments) 所示。如果你要使用自定义模块,将它放在一个.lua脚本中,然后加载它:当你的脚本返回一个torch.nn对象时,使用命令 require 'module_file.lua 加载它;当你的脚本返回一个torch.nn类时,使用 (require 'module_file.lua')() 加载它。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值