前言:类似于keras中的序贯模型,当模型较简单的时候,可以使用torch.nn.Sequential类来实现简单的顺序连接模型。Sequential类是继承自Module类的
Sequential类的定义:
class Sequential(Module): # 继承Module
def __init__(self, *args): # 重写了构造函数
def _get_item_by_idx(self, iterator, idx):
def __getitem__(self, idx):
def __setitem__(self, idx, module):
def __delitem__(self, idx):
def __len__(self):
def __dir__(self):
def forward(self, input): # 重写关键方法forward
Sequential类不同的实现(3种实现)
实现一:
import torch.nn as nn
model = nn.Sequential(
nn.Conv2d(in_channels=1,out_channels=32,kernel_size=3),
nn.ReLU(),
nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3),
nn.ReLU()
)
print("运行结果为:")
print(model)
print(model[0])
'''
运行结果为:
Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(3): ReLU()
)
Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
'''
注意:上面实现上存在一个问题,每一个层是没有名称,默认的是以0、1、2、3来命名,从上面的运行结果也可以看出。
实现二:
import torch.nn as nn
from collections import OrderedDict
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(in_channels=1,out_channels=32,kernel_size=3)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3)),
('relu2', nn.ReLU())])
)
print("运行结果为:")
print(model)
print(model[2])
'''
运行结果为:
Sequential(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(relu1): ReLU()
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(relu2): ReLU()
)
Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
'''
注意:从结果中可以看出,虽然每一个层都有了自己的名称,但是并不能够通过名称直接获取层,依然只能通过索引index,即
model[2] 是正确的
model["conv2"] 是错误的
这其实是由它的定义实现的,看上面的Sequenrial定义可知,只支持index访问。
实现三:
import torch.nn as nn
model = nn.Sequential()
model.add_module('conv1', nn.Conv2d(in_channels=1,out_channels=32,kernel_size=3))
model.add_module('relu1', nn.ReLU())
model.add_module('conv2', nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3))
model.add_module('relu2', nn.ReLU())
print("运行结果为:")
print(model)
print(model[2])
'''
运行结果为:
Sequential(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(relu1): ReLU()
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(relu2): ReLU()
)
Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
'''
实现三与keras做法类似,实际上Sequential类并没有定义add_module()方法,实际上这个方法是定义在它的父类Module,Sequential继承而来。