pytorch模型创建
模型构建两要素:构建子模块和拼接子模块
构建子模块
是在__init__()
函数中实现的。例如构建一个LeNet网络
class LeNet(nn.Module):
def __init__(self, classes):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, classes)
拼接子模块
是在forward
函数中实现的,同时是针对LeNet网络
def forward(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
nn.Module
构建子模块的时候发现都是基于nn.Module
实现的
例如上面的LeNet
网络,在开始的是时候class LeNet(nn.Module):
表明LeNet
是继承nn.Module
然后super(LeNet, self).__init__()
即实现父类函数调用的功能,调用nn.Module
的init()
函数,实现初始化
接着到达Conv2d
卷积层,因此Conv2d这个是继承的module的属性,
在赋值的数据类型,是nn.module
还是nn.Parameter
,分别存储在parameters
和modules
中,从而加载Conv2d