Pytorch学习笔记之模型的查看,保存及加载

打印模型

当我们写好一个model后,可以通过打印来查看这个model的每一层的模块。

class Bottle(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size):
        super(Bottle, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self,x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.pool(x)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()

        self.bottle_1 = Bottle(3,6,5)
        self.bottle_2 = Bottle(6,16,5)

        self.fc = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                nn.ReLU(),
                                nn.Linear(120, 84),
                                nn.ReLU())
        self.last_fc = nn.Linear(84, 10)

    def forward(self,x):
        x = self.bottle_1(x)
        x = self.bottle_2(x)
        x = x.view(-1, 16 * 5 * 5)
        x = self.fc(x)
        x = self.last_fc(x)
        return x

这是一个写好的模型,如果我们在没有模型源码的情况下,想知道模型细节,只需打印出来即可。

if __name__ == '__main__':
	model= Net()
	print(model)

在这里插入图片描述
打印信息中包含每一层模块的名称和模块的具体细节参数,另外如果是Sequential模块里面的子模块没有名称的话,则用数字0,1,2,3代替。

由于python特性我们可以访问到类里面的成员变量,所以可以很轻松的修改模块。比如,

    model= Net()
    model.last_fc = nn.Linear(84,20)

将最后一层全连接层的输出从10变成了20。

nn.Module还有很多成员函数,对我们操作模型非常有帮助。

add_module()

这个函数用于构建模型时添加子模块。

    model= nn.Sequential()
    model.add_module('linear_1',nn.Linear(10,30))
    model.add_module( 'tanh',nn.Tanh())

children(), modules()

返回模型的子模块细节,和打印模块效果一样。

    model = nn.Sequential(OrderedDict({'linear_1' : nn.Linear(10,30),
                            'tanh':nn.Tanh(),
                            'linear_2': nn.Linear(30,5),
                            'sigmod': nn.Sigmoid()}))
    for child in model.children(): 
       print(child)

在这里插入图片描述

named_children(), named_modules()

返回时除了有子模块外,还有该子模块的名字。

    for name,child in model.named_children():
       print(name,child)

parameters(), named_parameters()

这个函数非常有用,可以返回每个子模块的参数。必要时可以做适当修改。比如修改参数的requires_grad属性。

    model = Net()
    for name,param in model.named_parameters():
       print(name,param)

requires_grad_()

这个函数可以设置每个模块是否需要自动求梯度。是个in-place操作。

state_dict()

这个函数和named_parameters()一样,返回模型的的各个子模型的名字和参数。不同的是这个函数返回的是字典。

    model = Net()
    for name,param in model.state_dict().items():
       print(name,param)

load_state_dict()

用于将已有的参数复制到模型上,用于模型数据恢复。

模型保存与加载

模型的保存与加载一般是通过torch.save函数和torch.load函数来实现,这两个函数分别通过序列化和反序列化来保存和加载模型。实现的方式有两种,第一种是将模型网络结构和参数都保存。

model = Net()
torch.save(model,'./model.pth')
model = torch.load('./model.pth')

另外一种方法则是仅保存参数,

model = Net()
torch.save(model.state_dict(),'./model.pth')
model.load_state_dict(torch.load('./model.pth'))
model = Net()
state_dict = {'state_dict':model.state_dict()}
torch.save(state_dict,'./model.pth')
checkpoint = torch.load('./model.pth')
model.load_state_dict(checkpoint['state_dict'])

如果模型参数是通过url提供的,则可以使用torch,utils.model_zoo提供的load_url()函数来加载参数。

import torch.utils.model_zoo as model_zoo

model = Net()
model.load_state_dict(model_zoo.load_url(URL))
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值