博客新址: http://blog.xuezhisd.top
邮箱:xuezhisd@126.com
如何将MXNet用作Torch的前后端
本章节描述了如何将MXNet用作Torch的两个主要功能(前端和后端):
-
使用MXNet.NDArray来调用Torch的张量数学函数。
-
将Torch的神经网络模块(层)嵌入到MXNet的符号图中。
编译支持Torch的MXNet
- 参照 官方教程 来安装Torch
- 如果还没有安装Torch,将配置文件
make/config.mk
(Linux) 或make/osx.mk
(Mac) 复制到MXNet根目录中,并命名为config.mk
。取消文件config.mk
中的两行注释:TORCH_PATH = $(HOME)/torch
和MXNET_PLUGINS += plugin/torch/torch.mk
。 - 此处默认Torch安装在当前用户的主目录下(
TORCH_PATH = $(HOME)/torch
)。如果Torch没有安装在此目录,将参数TORCH_PATH
修改成torch的安装目录。
- 如果还没有安装Torch,将配置文件
- 运行命令
make clean && make
来构建可以使用Torch的MXNet。
与张量相关的数学函数
mxnet.th模块支持调用Torch的张量数学函数和mxnet.nd.NDArray一起使用。查看 完整代码:
import mxnet as mx
x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print x.asnumpy()
y = mx.th.abs(x)
print y.asnumpy()
x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print x.asnumpy()
mx.th.abs(x, x) # 原地计算
print x.asnumpy()
使用命令 help(mx.th)
获取更多帮助。
现在我们已经支持网页 Torch’s documentation page.上的最常用的函数。如果你发现你需要的函数还没有支持,你可以通过参考已经登记的函数,轻易地将它登记在页面 mxnet_root/plugin/torch/torch_function.cc
上。
Torch 模块 (网络层)
MXNet通过mxnet.symbol.TorchModule
模块来支持Torch的神经网络模块。比如,下面的代码定义了一个对MNIST数据库进行分类的3层DNN网络。 (完整代码):
data = mx.symbol.Variable('data')
fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1')
act1 = mx.symbol.TorchModule(data_0=fc1, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu1')
fc2 = mx.symbol.TorchModule(data_0=act1, lua_string='nn.Linear(128, 64)', num_data=1, num_params=2, num_outputs=1, name='fc2')
act2 = mx.symbol.TorchModule(data_0=fc2, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu2')
fc3 = mx.symbol.TorchModule(data_0=act2, lua_string='nn.Linear(64, 10)', num_data=1, num_params=2, num_outputs=1, name='fc3')
mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
下面,分析一下上述代码。首先 data = mx.symbol.Variable('data')
定义一个符号变量作为输入的占位符。然后,fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1')
将符号变量data传递给Torch的NN模块。如果你想使用Torch的Criterion作为损失函数,只需将最后一行替换成:
logsoftmax = mx.symbol.TorchModule(data_0=fc3, lua_string='nn.LogSoftMax()', num_data=1, num_params=0, num_outputs=1, name='logsoftmax')
# Torch的标签从1开始
label = mx.symbol.Variable('softmax_label') + 1
mlp = mx.symbol.TorchCriterion(data=logsoftmax, label=label, lua_string='nn.ClassNLLCriterion()', name='softmax')
nn模块的输入数据的命名估规则是 data_i,其中 i = 0 … num_data-1。 lua_string
是一个用来创建模块对象的单行Lua语句;对于Torch的内建模块,形式如nn.module_name(arguments)
所示。如果你要使用自定义模块,将它放在一个.lua
脚本中,然后加载它:当你的脚本返回一个torch.nn对象时,使用命令 require 'module_file.lua
加载它;当你的脚本返回一个torch.nn类时,使用 (require 'module_file.lua')()
加载它。