在使用深度学习平台时,光会使用其中已定义好的操作有时候是满足不了实际使用的,一般需要我们自己定义新的操作。但是,绝大多数深度平台都是编译好的,很难再次编写。本文以Mxnet为例,官方给出四种定义新操作的方法,
分别调用:
1、mx.operator.CustomOp
2、mx.operator.NDArrayOp
3、mx.operator.NumpyOp
4、使用 C++ 定义底层
并且给出了重新定义softmax层的例子。但是sofetmax操作只有前向操作,也没有参数,与我们通常需要需要使用的情况不符,官方文档也没有一个有参数的中间层例子。在此博主给出了一个重新定义全连接操作的例子,希望能够给大家带来帮助。
# pylint: skip-file
import os
from data import mnist_iterator
import mxnet as mx
import numpy as np
import logging
from numpy import *
class Dense(mx.operator.CustomOp):
def __init__(self, num_hidd