安装
pip下安装:
pip install torch-summary
conda下安装:
conda install torch-summary
注:torchsummary与torch-summary是两个不同库!后者是前者的升级版,添加更多功能且解决了部分bug,因此推荐使用torch-summary!
使用
通过nn.Module构建一个模型(以一个简单的LSTM为例):
import torch.nn as nn
class LSTM(nn.Module):
def __init__(self):
super(LSTM, self).__init__()
self.lstm = nn.LSTM(input_size=1, hidden_size=16, batch_first=True)
self.mlp = nn.Sequential(
nn.Linear(16, 8),
nn.ReLU(),
nn.Linear(8, 1)
)
def forward(self, x):
x = self.lstm(x)[0][:, -1, :]
res = self.mlp(x)
return res
if __name__ == '__main__':
model = LSTM()
input = torch.randn(8, 32, 1)
output = model(input)
使用torch-summary可视化网络:
import torchsummary
model = LSTM()
torchsummary.summary(model, input_size=(32, 1), batch_size=8)
结果输出:
图中可看出模型的层次结构以及各层的参数统计:包括LSTM和Sequential层,分别包含1216、145个参数,其中Sequential层的两个Linear层分别包含136、9个参数。