一、自定义lasagne层
例如，下面定义一个 Highway Network 层（Highway Network 的原理网上资料很多，这里不再赘述）。
1）首先继承 Lasagne 的基础层类：`class HighwayNetwork2D(lasagne.layers.Layer):`
可以看到，自定义层继承自 `lasagne.layers.Layer`。
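2）然后在 `__init__()` 中定义 Highway Network 需要训练（更新）的参数：两个权重 `h_w`、`t_w`，形状为 `(num_filters, cnn_size)`；两个偏置 `h_b`、`t_b`，形状为 `(cnn_size,)`。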
def __init__(self, incoming, h_w=lasagne.init.Normal(), h_b=lasagne.init.Normal(),
             t_w=lasagne.init.Normal(), t_b=lasagne.init.Normal(), **kwargs):
    """Create the four trainable parameters of the highway layer.

    Parameters
    ----------
    incoming
        Passed straight to ``lasagne.layers.Layer.__init__``. Its output
        must be 3D, ``(batch, num_filters, cnn_size)``, with known
        ``num_filters`` and ``cnn_size``.
    h_w, h_b
        Weight ``(num_filters, cnn_size)`` and bias ``(cnn_size,)`` of the
        transform ``H(x)``; anything accepted by ``Layer.add_param``.
    t_w, t_b
        Weight ``(num_filters, cnn_size)`` and bias ``(cnn_size,)`` of the
        transform gate ``t``; anything accepted by ``Layer.add_param``.
    **kwargs
        Forwarded to ``lasagne.layers.Layer.__init__``.

    Raises
    ------
    ValueError
        If the input is not 3D, or its sizes on axis 1 / axis 2 are unknown.
    """
    super(HighwayNetwork2D, self).__init__(incoming, **kwargs)
    # The parameter shapes below are built from axes 1 and 2 and are meant
    # to broadcast against a 3D input; reject anything else up front rather
    # than failing (or silently mis-broadcasting) later.
    if len(self.input_shape) != 3:
        raise ValueError(
            "HighwayNetwork2D expects a 3D input (batch, num_filters, "
            "cnn_size), got input_shape={!r}".format(self.input_shape))
    num_filters = self.input_shape[1]
    cnn_size = self.input_shape[2]
    # add_param needs concrete sizes to allocate the parameters.
    if num_filters is None or cnn_size is None:
        raise ValueError(
            "HighwayNetwork2D needs known num_filters and cnn_size, "
            "got input_shape={!r}".format(self.input_shape))
    self.h_w = self.add_param(h_w, (num_filters, cnn_size), name='h_w')
    self.h_b = self.add_param(h_b, (cnn_size,), name='h_b')
    self.t_w = self.add_param(t_w, (num_filters, cnn_size), name='t_w')
    self.t_b = self.add_param(t_b, (cnn_size,), name='t_b')
3）在 `get_output_for()` 中定义该层实际的计算过程（注意这里的 `*` 是逐元素乘法，而不是矩阵乘法）：
def get_output_for(self, input, **kwargs):
    """Apply the highway transform to the symbolic input.

    Computes ``z = t * H(x) + (1 - t) * x`` where
    ``H(x) = tanh(W_h * x + b_h)`` and ``t = sigmoid(W_t * x + b_t)``.
    All products are element-wise (Theano ``*``), not matrix products.

    Parameters
    ----------
    input : Theano tensor
        Symbolic input; presumably 3D ``(batch, num_filters, cnn_size)`` so
        that it broadcasts against the parameters created in ``__init__``.
    **kwargs
        Unused; accepted for the Lasagne ``get_output_for`` interface.

    Returns
    -------
    Theano tensor
        The gate-weighted mix of the transformed and the original input.
    """
    transformed = T.tanh(self.h_w * input + self.h_b)    # H(x)
    gate = T.nnet.sigmoid(self.t_w * input + self.t_b)   # t, in (0, 1)
    # gate -> 1 passes the transformed signal; gate -> 0 carries x through.
    return gate * transformed + (1 - gate) * input
4）最后在 `get_output_shape_for()` 中给出输出的形状（与输入形状相同）：
def get_output_shape_for(self, input_shape):
    """Return the output shape of the highway layer.

    Every operation in ``get_output_for`` (``*``, ``+``, ``tanh``,
    ``sigmoid``) is element-wise, so the output has exactly the input's
    shape. Returning the whole shape, instead of hard-coding three entries,
    keeps the reported shape correct rather than silently truncating it.

    Parameters
    ----------
    input_shape : sequence
        Shape of the incoming tensor; ``None`` entries (unknown sizes) are
        passed through unchanged.

    Returns
    -------
    tuple
        The same shape, as a tuple.
    """
    return tuple(input_shape)
完整的自定义层代码如下：
class HighwayNetwork2D(lasagne.layers.Layer):
    """Element-wise highway layer for 3D inputs.

    The input is expected to have shape ``(batch, num_filters, cnn_size)``.
    The layer computes::

        H(x) = tanh(W_h * x + b_h)
        t    = sigmoid(W_t * x + b_t)
        z    = t * H(x) + (1 - t) * x

    Here ``*`` is element-wise multiplication, not a matrix product. Each
    weight has shape ``(num_filters, cnn_size)`` and broadcasts over the
    batch axis. Each bias has shape ``(cnn_size,)`` and broadcasts over the
    batch and filter axes. The output therefore has the input's shape.

    Parameters
    ----------
    incoming
        Passed straight to ``lasagne.layers.Layer.__init__``. Its output
        must be 3D with known sizes on axes 1 and 2.
    h_w, h_b
        Weight and bias of the transform ``H(x)``; anything accepted by
        ``Layer.add_param``.
    t_w, t_b
        Weight and bias of the transform gate ``t``; anything accepted by
        ``Layer.add_param``.
    **kwargs
        Forwarded to ``lasagne.layers.Layer.__init__``.
    """

    def __init__(self, incoming, h_w=lasagne.init.Normal(), h_b=lasagne.init.Normal(),
                 t_w=lasagne.init.Normal(), t_b=lasagne.init.Normal(), **kwargs):
        """Validate the input shape and create the four parameters.

        Raises
        ------
        ValueError
            If the input is not 3D, or its sizes on axis 1 / axis 2 are
            unknown.
        """
        super(HighwayNetwork2D, self).__init__(incoming, **kwargs)
        # The parameter shapes are built from axes 1 and 2 and are meant to
        # broadcast against a 3D input; reject anything else up front rather
        # than failing (or silently mis-broadcasting) later.
        if len(self.input_shape) != 3:
            raise ValueError(
                "HighwayNetwork2D expects a 3D input (batch, num_filters, "
                "cnn_size), got input_shape={!r}".format(self.input_shape))
        num_filters = self.input_shape[1]
        cnn_size = self.input_shape[2]
        # add_param needs concrete sizes to allocate the parameters.
        if num_filters is None or cnn_size is None:
            raise ValueError(
                "HighwayNetwork2D needs known num_filters and cnn_size, "
                "got input_shape={!r}".format(self.input_shape))
        self.h_w = self.add_param(h_w, (num_filters, cnn_size), name='h_w')
        self.h_b = self.add_param(h_b, (cnn_size,), name='h_b')
        self.t_w = self.add_param(t_w, (num_filters, cnn_size), name='t_w')
        self.t_b = self.add_param(t_b, (cnn_size,), name='t_b')

    def get_output_for(self, input, **kwargs):
        """Return ``t * H(x) + (1 - t) * x`` for the symbolic input ``x``.

        ``**kwargs`` is unused and is accepted for the Lasagne interface.
        """
        transformed = T.tanh(self.h_w * input + self.h_b)    # H(x)
        gate = T.nnet.sigmoid(self.t_w * input + self.t_b)   # t, in (0, 1)
        # gate -> 1 passes the transformed signal; gate -> 0 carries x through.
        return gate * transformed + (1 - gate) * input

    def get_output_shape_for(self, input_shape):
        """Return ``input_shape`` as a tuple.

        Every operation in the layer is element-wise, so the shape is
        unchanged.
        """
        return tuple(input_shape)