mxnet-增加新层(cpp)[代码]

重点

  • mx.operator.CustomOP是无状态的量,即每次外部调用forward,都会重新构造一个新的mx.operator.CustomOP实例。状态可以保存在nn.Block中,利用aux参数传递

调用关系图

Proto2DBlock(nn.Block)-> Proto2DProp(mx.operator.CustomOpProp) -> Proto2D(mx.operator.CustomOp)

代码注释

“`

class Proto2D(mx.operator.CustomOp): #custom Op的实现代码
def init(self, channels,kernelSize):

    return


def forward(self, is_train, req, in_data, out_data, aux):
    ctx = in_data[0].context
    data = in_data[0] #按照Proto2DProp的约定,in_data第一个元素是data(input)
    proj = aux[0] #Proto2DProp的约定,aux第一个元素是project_action_code
    origin_images = aux[1]#Proto2DProp的约定,aux第二个元素是origin_image
    weight = in_data[1]#按照Proto2DProp的约定,in_data第二个元素是weight

    self.assign(in_data[1],"write",weight) #修改weiht
    self.assign(out_data[0],req[0],output) #修改output(参考Proto2DProp约定)
    return

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
    dataIn = in_data[0] #参考Proto2DProp
    dataOut = out_data[0] #参考Proto2DProp
    weight = in_data[1] #参考Proto2DProp

    self.assign(in_grad[0],req[0],dz) #输出第一个元素是传递到下一层的梯度
    self.assign(in_grad[1],req[1],dw) #输出第二个元素是本层weight的权重
    return

声明operator的输入/输出/参数,shape/type inference

@mx.operator.register(“proto2d”) #这一句注册operator的名字,下面会用到
class Proto2DProp(mx.operator.CustomOpProp):
def init(self,channels,kernelSize): #parameters during initialization
super(Proto2DProp,self).init(need_top_grad=True)
self.kernelSize = int(kernelSize)
self.channels = int(channels)
def list_arguments(self): #parameter during forward()
return [“input”,”weight”] #指定Proto2D::forward()的in_data参数是包含两个元素的list,第一个是input,第二个是weight
def list_outputs(self): #output of forward
return [‘output’]
def list_auxiliary_states(self): #aux parameter during forward()
return [‘project_action’,’origin_image_shape’]#指定Proto2D::forward()的aux参数是包含两个元素的list,第一个是project_action,第二个是origin_image_shape
def infer_shape(self, in_shape): #指定每个参数的shape
data_shape = in_shape[0]
weight_shape = in_shape[1]
output_shape = (data_shape[0],weight_shape[0],data_shape[2],data_shape[3]) #
project_action_shape = (1,)
origin_image_shape = (data_shape[0],3, origin_image_size , origin_image_size) #origin image shape
return (data_shape,weight_shape),(output_shape,),(project_action_shape,origin_image_shape,)
def infer_type(self, in_type): #指定每个参数的type
dtype = in_type[0]
return (dtype,dtype),(dtype,),(np.int32,dtype)
def create_operator(self, ctx, in_shapes, in_dtypes): #调用Proto2D
return Proto2D(self.channels,self.kernelSize)

用gluon封装operator

class Proto2DBlock(nn.Block):
def init(self,in_channels, out_channels,kernel_size=3, **kwargs):
super(Proto2DBlock,self).init(**kwargs)
#configure parameters
self.kernelSize = kernel_size
self.channels = out_channels
#learnable parameters
self.weights = self.params.get(“weight”,shape = (out_channels, in_channels, kernel_size, kernel_size)) #define shape of kernel
return
def forward(self,x, *args): #只需要实现forward
ctx = x.context
#mx.nd.Custom()接口的参数
#op_type指定前面注册的op的名字
#position参数包括两部分:input + aux
# input 在 Proto2DProp::list_arguments()中指定
# aux在Proto2DProp::list_auxiliary_states()中指定
#dict参数在Proto2DProp初始化函数中构造
y = mx.nd.Custom(x,self.weights.data(ctx), self.project_action, self.origin_image, channels=self.channels, kernelSize=self.kernelSize, op_type=”proto2d”)
return y
“`

创建方式

Proto2DBlock(80,64) #即gluon封装的接口

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值