1.容器
Containers
nn.Sequential:按照顺序包装多个网络层
nn.ModuleList:像python的list一样包装多个网络层
nn.ModuleDict:像python的dict一样包装多个网络层
2.容器之Sequential
features
classifier
conv1
pool1
conv2
pool2
fc1
fc2
fc3
import torch
import torchvision
import torch. nn as nn
from collections import OrderedDict
class LeNetSequential(nn.Module):
    """LeNet built from plain nn.Sequential containers.

    Sub-modules inside each Sequential are auto-numbered ('0', '1', ...),
    since no explicit names are given.
    """

    def __init__(self, classes):
        super(LeNetSequential, self).__init__()
        # Feature extractor: two conv -> ReLU -> max-pool stages.
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Fully-connected classifier head ending in `classes` logits.
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes),
        )

    def forward(self, x):
        feats = self.features(x)
        # Flatten every dimension except the batch dimension.
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)
class LeNetSequentialOrderDict(nn.Module):
    """LeNet built with nn.Sequential + OrderedDict so each layer has a readable name.

    Fix: OrderedDict is constructed from a list of (name, module) tuples
    instead of a dict literal. A dict literal only preserves insertion
    order on Python >= 3.7, so the old ``OrderedDict({...})`` form could
    scramble the layer order on older interpreters — the tuple form
    guarantees order everywhere.
    """

    def __init__(self, classes):
        super(LeNetSequentialOrderDict, self).__init__()
        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 6, 5)),
            ('relu1', nn.ReLU(inplace=True)),
            ('pool1', nn.MaxPool2d(kernel_size=2, stride=2)),
            ('conv2', nn.Conv2d(6, 16, 5)),
            ('relu2', nn.ReLU(inplace=True)),
            ('pool2', nn.MaxPool2d(kernel_size=2, stride=2)),
        ]))
        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(16 * 5 * 5, 120)),
            ('relu3', nn.ReLU()),
            ('fc2', nn.Linear(120, 84)),
            ('relu4', nn.ReLU(inplace=True)),
            ('fc3', nn.Linear(84, classes)),
        ]))

    def forward(self, x):
        x = self.features(x)
        # Flatten to (batch, 16*5*5) before the fully-connected head.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
# Smoke test: build the named-layer LeNet and push a random batch through it.
model = LeNetSequentialOrderDict(classes=2)
dummy_batch = torch.randn((4, 3, 32, 32), dtype=torch.float32)
result = model(dummy_batch)
print(model)
print(result)
nn.Sequential是nn.Module的容器,用于按顺序包装一组网络层。顺序性:各网络层之间严格按照顺序构建。自带forward():内部通过for循环按顺序取出各子模块,以 x = 模块1(x)、x = 模块2(x) 的形式依次处理数据,返回最终输出。
3、容器之ModuleList
nn.ModuleList是nn.Module的容器,用于包装一组网络层,以迭代的方式调用网络层,主要方法是:
append():在ModuleList后面添加网络层 extend():拼接两个ModuleList insert():在ModuleList的指定位置插入网络层
class ModuleList(nn.Module):
    """Demo: 20 stacked Linear(10, 10) layers held in an nn.ModuleList.

    nn.ModuleList registers every layer's parameters with the parent
    module but has no forward() of its own — the layers are applied one
    by one in this class's forward().
    """

    def __init__(self):
        super(ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])

    def forward(self, x):
        # Fix: the enumerate() index was never used; iterate the layers directly.
        for linear in self.linears:
            x = linear(x)
        return x
# Smoke test: feed an all-ones batch through the 20-layer linear stack.
stack = ModuleList()
print(stack)
ones_batch = torch.ones((10, 10))
result = stack(ones_batch)
print(result)
4、容器之ModuleDict
5、容器总结:
nn.Sequential:顺序性,各网络层之间严格按顺序执行,常用于block构建 nn.ModuleList:用于大量重复网络构建,通过for重复实现循环构建 nn.ModuleDict:索引性,常用于可选择的网络层