Pytorch中所有模型都是基于Module这个类,也就是说无论是自定义的模型,还是Pytorch中已有的模型,都是这个类的子类,并重写了forward方法。Pytorch中创建模型有几种方法。
继承Module
这是最直接的方法,自己写一个模型继承Module,并重写forward方法。
from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F
class LinearMoudule(Module):
def __init__(self):
super(LinearMoudule,self).__init__()
self.linear_1 = nn.Linear(10,30)
self.linear_2 = nn.Linear(30,5)
def forward(self,x):
x = self.linear_1(x)
x = F.tanh(x)
x = self.linear_2(x)
x = F.sigmoid(x)
return x
使用Sequential
使用Sequential是一种快速构建模型的方法,只需将需要添加的模型放入其构造函数即可。
from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F
module = nn.Sequential(nn.Linear(10,30),
nn.Tanh(),
nn.Linear(30,5),
nn.Sigmoid())
另外还可以用OrderedDict来对每一层模型进行命名。
from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
module = nn.Sequential(OrderedDict(
{'linear_1':nn.Linear(10,30),
'tanh':nn.Tanh(),
'linear_2': nn.Linear(30,5),
'sigmod': nn.Sigmoid()}
))
module = nn.Sequential(OrderedDict(
[('linear_1', nn.Linear(10,30)),
('tanh',nn.Tanh()),
('linear_2', nn.Linear(30,5)),
('sigmod', nn.Sigmoid())
))
#两种写法都可以
ModuleList和ModuleDict
这两个类顾名思义,是分别通过List和Dict两种容器将模块进行包装来创建新模型的。并且这两个类可以通过迭代来访问。
from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F、
class Module_List(nn.Module):
def __init__(self):
super(Module_List, self).__init__()
self.modules = nn.ModuleList([nn.Linear(10, 30),nn.Tanh(),nn.Linear(30,5),nn.Sigmoid()])
def forward(self, x):
for layer in self.modules:
x = layer(x)
return x
class Module_Dict(nn.Module):
def __init__(self):
super(Module_Dict, self).__init__()
self.modules = nn.ModuleDict({'linear_1' : nn.Linear(10,30),
'tanh':nn.Tanh(),
'linear_2': nn.Linear(30,5),
'sigmod': nn.Sigmoid()})
def forward(self, x):
for layer in self.modules:
x = layer(x)
return x
混合使用
前面几种创建模型的方法可以混合使用,来创建更为复杂的模型。
class Bottle(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size):
super(Bottle, self).__init__()
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
def forward(self,x):
x = self.conv(x)
x = self.relu(x)
x = self.pool(x)
return x
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.bottle_1 = Bottle(3,6,5)
self.bottle_2 = Bottle(6,16,5)
self.fc = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU())
self.last_fc = nn.Linear(84, 10)
def forward(self,x):
x = self.bottle_1(x)
x = self.bottle_2(x)
x = x.view(-1, 16 * 5 * 5)
x = self.fc(x)
x = self.last_fc(x)
return x