Pytorch 笔记 -- model模型

1. 保存载入模型

import torch
torch.save(model,'model_name')  #将网络模型、模型参数全部保存
torch.save(model.state_dict(),'model_name.pkl')  #只保存模型参数

两种载入方式也不相同

torch.load('model_name')

from torchvision import models
model = models.resnet18()   #载入网络
model.load_state_dict(torch.load('model_name.pkl')) #载入参数

2. 查看模型参数

for param_name, param in model.named_parameters():
	print (param_name,param.shape)

3.修改模型参数

param.copy_(input_param)   #本程序待验证,摘自torch.nn.modules.module.py :656行 param 为原来的参数,input_param 为新的参数,两者维度相同

4. CLASS torch.nn.Module

下面记录此类包含函数的使用方法

  • add_module(name, module)
    在当前的模块中增加一个子模块
model.add_module(name, module)
...
  • children()
    返回子模块的迭代器。

查看model下的子模块

from module in model.children()
	pass
  • cpu()
    将所有模型参数和缓冲区移动到CPU。(Moves all model parameters and buffers to the CPU.)

使用cpu载入模型

model.cpu()
  • cuda(device=None)
    将所有模型参数和缓冲区移动到GPU。(Moves all model parameters and buffers to the GPU.)
    在构造优化器之前使用它

使用gpu载入模型

model.cuda()
  • double()
    将所有浮点参数和缓冲区转换为double数据类型。
  • float()
    将所有浮点参数和缓冲区转换为float数据类型。
  • eval()
    设置模型为评估模式,在测试模型之前使用
model.eval()
  • half()
    Casts all floating point parameters and buffers to half datatype.
  • forward(*input)
    定义每次调用时执行的计算。应该被所有子类覆盖。(Defines the computation performed at every call.
    Should be overridden by all subclasses.)
    每次调用网络模型时会执行此函数
    待续。。。
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值