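To make code reusable, it helps to build a network out of sub-networks. PaddlePaddle offers two ways to build sub-networks (see the official documentation). A network such as LSTM, however, cannot be placed directly in a `Sequential` container: the LSTM's output is a tuple (see the official documentation), which cannot be passed straight to the next layer.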
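Here is one way around this:
* Create a new class that inherits from `LSTM` and overrides its `forward` method.
* In the new `forward`, call the parent `LSTM`'s `forward` method.
* Intercept the parent's output, process it as needed, and return the result.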
For example:

```python
import paddle


# Custom class that inherits from LSTM
class MyLSTM(paddle.nn.LSTM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, inputs):
        # Call the parent forward and discard the (h, c) final states
        output, _ = super().forward(inputs)
        # Keep only the last time step and pass it on, since the next layer is not an LSTM
        return output[:, -1, :]


class T(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.s = paddle.nn.Sequential(
            MyLSTM(5, 32),
            paddle.nn.BatchNorm(32))

    def forward(self, inputs):
        return self.s(inputs)


model = T()
paddle.summary(model, (1, 20, 5))
```
```text
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #
===========================================================================
   MyLSTM-6         [[1, 20, 5]]           [1, 32]             4,992
  BatchNorm-6        [[1, 32]]             [1, 32]              128
===========================================================================
Total params: 5,120
Trainable params: 4,992
Non-trainable params: 128
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.02
Estimated Total Size (MB): 0.02
---------------------------------------------------------------------------

{'total_params': 5120, 'trainable_params': 4992}
```
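The parameter counts in the summary can be checked by hand. Assuming Paddle's LSTM keeps, per direction, an input-to-hidden weight of shape `[4*hidden, input]`, a hidden-to-hidden weight of shape `[4*hidden, hidden]` and two bias vectors of length `4*hidden`, a single layer has `4h(i + h) + 8h` parameters per direction. A small pure-Python sketch (the helper name `lstm_param_count` is hypothetical):

```python
def lstm_param_count(input_size, hidden_size, num_directions=1):
    """Parameters of a single-layer LSTM, assuming a per-direction layout of
    weight_ih [4h, i], weight_hh [4h, h], bias_ih [4h] and bias_hh [4h]."""
    h = hidden_size
    per_direction = 4 * h * input_size + 4 * h * h + 2 * 4 * h
    return num_directions * per_direction


print(lstm_param_count(5, 32))           # 4992, the MyLSTM(5, 32) row
print(lstm_param_count(5, 32) + 4 * 32)  # 5120: BatchNorm(32) adds weight, bias, mean and variance
```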
##### How do you write a stacked multi-layer LSTM?
Here is an example:
```python
import paddle


class MyLSTM(paddle.nn.LSTM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, inputs):
        output, _ = super().forward(inputs)
        # Bidirectional layers return the full sequence so the next LSTM can consume it;
        # unidirectional layers return only the last time step
        if self.num_directions == 2:
            return output
        else:
            return output[:, -1, :]


class T(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.s = paddle.nn.Sequential(
            # Must be "bidirectional" so that MyLSTM returns the full sequence
            MyLSTM(5, 32, direction="bidirectional"),
            # 64 is twice the previous hidden size: a bidirectional LSTM
            # concatenates its forward and backward outputs
            MyLSTM(64, 8),
            paddle.nn.BatchNorm(8))

    def forward(self, inputs):
        return self.s(inputs)


model = T()
paddle.summary(model, (1, 20, 5))
```

```text
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #
===========================================================================
  MyLSTM-55         [[1, 20, 5]]         [1, 20, 64]           9,984
  MyLSTM-56        [[1, 20, 64]]           [1, 8]              2,368
 BatchNorm-27        [[1, 8]]              [1, 8]                32
===========================================================================
Total params: 12,384
Trainable params: 12,352
Non-trainable params: 32
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.05
Estimated Total Size (MB): 0.06
---------------------------------------------------------------------------

{'total_params': 12384, 'trainable_params': 12352}
```