打印模型
当我们写好一个model后,可以通过打印来查看这个model的每一层的模块。
class Bottle(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size):
super(Bottle, self).__init__()
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
def forward(self,x):
x = self.conv(x)
x = self.relu(x)
x = self.pool(x)
return x
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.bottle_1 = Bottle(3,6,5)
self.bottle_2 = Bottle(6,16,5)
self.fc = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU())
self.last_fc = nn.Linear(84, 10)
def forward(self,x):
x = self.bottle_1(x)
x = self.bottle_2(x)
x = x.view(-1, 16 * 5 * 5)
x = self.fc(x)
x = self.last_fc(x)
return x
这是一个写好的模型,如果我们在没有模型源码的情况下,想知道模型细节,只需打印出来即可。
if __name__ == '__main__':
model= Net()
print(model)
打印信息中包含每一层模块的名称和模块的具体细节参数,另外如果是Sequential模块里面的子模块没有名称的话,则用数字0,1,2,3代替。
由于python特性我们可以访问到类里面的成员变量,所以可以很轻松的修改模块。比如,
model= Net()
model.last_fc = nn.Linear(84,20)
将最后一层全连接层的输出从10变成了20。
nn.Module还有很多成员函数,对我们操作模型非常有帮助。
add_module()
这个函数用于构建模型时添加子模块。
model= nn.Sequential()
model.add_module('linear_1',nn.Linear(10,30))
model.add_module( 'tanh',nn.Tanh())
children(), modules()
返回模型的子模块细节,和打印模块效果一样。
model = nn.Sequential(OrderedDict({'linear_1' : nn.Linear(10,30),
'tanh':nn.Tanh(),
'linear_2': nn.Linear(30,5),
'sigmod': nn.Sigmoid()}))
for child in model.children():
print(child)
named_children(), named_modules()
返回时除了有子模块外,还有该子模块的名字。
for name,child in model.named_children():
print(name,child)
parameters(), named_parameters()
这个函数非常有用,可以返回每个子模块的参数。必要时可以做适当修改。比如修改参数的requires_grad属性。
model = Net()
for name,param in model.named_parameters():
print(name,param)
requires_grad_()
这个函数可以设置每个模块是否需要自动求梯度。是个in-place操作。
state_dict()
这个函数和named_parameters()一样,返回模型的的各个子模型的名字和参数。不同的是这个函数返回的是字典。
model = Net()
for name,param in model.state_dict().items():
print(name,param)
load_state_dict()
用于将已有的参数复制到模型上,用于模型数据恢复。
模型保存与加载
模型的保存与加载一般是通过torch.save函数和torch.load函数来实现,这两个函数分别通过序列化和反序列化来保存和加载模型。实现的方式有两种,第一种是将模型网络结构和参数都保存。
model = Net()
torch.save(model,'./model.pth')
model = torch.load('./model.pth')
另外一种方法则是仅保存参数,
model = Net()
torch.save(model.state_dict(),'./model.pth')
model.load_state_dict(torch.load('./model.pth'))
model = Net()
state_dict = {'state_dict':model.state_dict()}
torch.save(state_dict,'./model.pth')
checkpoint = torch.load('./model.pth')
model.load_state_dict(checkpoint['state_dict'])
如果模型参数是通过url提供的,则可以使用torch,utils.model_zoo提供的load_url()函数来加载参数。
import torch.utils.model_zoo as model_zoo
model = Net()
model.load_state_dict(model_zoo.load_url(URL))