预训练模型参数的操作-pytorch

本文涉及操作权重参数常用API的解释,以及初始化权重的常用类型和方法

1、torch.load()
2、model.state_dict()
3、model.named_children()
4、model.children()
5、model.named_parameters()
6、torch.save()
7、model.load_state_dict()
8、初始化权重
9、net.parmeters()

import  torch
dir = r"rnet34.pth"

1、torch.load()

# 第一个参数为权重文件路径,第二个固定输入
model = torch.load(dir,map_location="cpu")

2、model.state_dict()

# 这个api的类型为有序字典,字典的键为网络层结构(conv,bn,bias,fc...),键对应的值为一个tensor,也就是该层结构的权重参数
# BN层也会有参数,这里的参数记录的是前一层结构的均值和方差
for k,v in model.state_dict().items():
    print(k, v)
# 选取全连接层部分的打印结果
fc.weight tensor([[-0.0240, -0.0264,  0.0397,  ..., -0.0377,  0.0322, -0.0025],
        [ 0.0402,  0.0193,  0.0344,  ...,  0.0114,  0.0208, -0.0120],
        [-0.0332,  0.0265, -0.0157,  ...,  0.0344, -0.0008,  0.0077],
        [ 0.0416, -0.0047,  0.0206,  ...,  0.0150,  0.0156, -0.0328],
        [ 0.0120, -0.0226,  0.0252,  ..., -0.0327, -0.0088, -0.0042]])
fc.bias tensor([-0.0003,  0.0042,  0.0223, -0.0133,  0.0210])

3、model.named_children()

# 这个api的类型为生成器,每一次的迭代生成一个元组,元组的第一个元素为层结构名称,第二个元素为层具体结构
for i in model.named_children():
    print(type(i))
    for j in i:
        print(j)
    break
# 打印循环一次的结果
<class 'tuple'>
conv1
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

4、model.children()

# 只输出具体层结构,没有层结构名称,类比第三部分
for layer in model.children():
    print(layer)
# 打印部分层结构
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)
MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

5、model.named_parameters()

# 这个api的类型为生成器,每一次的迭代生成一个元组,第一个元素为层结构名称,第二个元素为层包含权重参数,其类型为<class 'torch.nn.parameter.Parameter'>
for i in model.named_parameters():
    print(type(i))
    for j in i:
        print(j)
    break
# 选取全连接层部分的打印结果
<class 'tuple'>
fc.weight
Parameter containing:
tensor([[-0.0240, -0.0264,  0.0397,  ..., -0.0377,  0.0322, -0.0025],
        [ 0.0402,  0.0193,  0.0344,  ...,  0.0114,  0.0208, -0.0120],
        [-0.0332,  0.0265, -0.0157,  ...,  0.0344, -0.0008,  0.0077],
        [ 0.0416, -0.0047,  0.0206,  ...,  0.0150,  0.0156, -0.0328],
        [ 0.0120, -0.0226,  0.0252,  ..., -0.0327, -0.0088, -0.0042]],
       requires_grad=True)

6、torch.save() 三种形式

# 用于保存模型,有两种保存方式,第一种可以保存网络结构,第二种只保存权重参数
# 注意,函数第二个参数必须以.pth或.pt结尾
torch.save(model,"model.pth")

torch.save(dict,"model.pth")
dict = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'model': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
                'scheduler': lr_scheduler.state_dict(),
            }

torch.save(model.state_dict(),"model.pth")

7、model.load_state_dict()

# 加载训练好的模型
import torch
from model import resnet34

dir = r"rnet34.pth"
model_weight = torch.load(dir,map_location="cpu")
net = resnet34(num_classes=5)
net.load_state_dict(model_weight.state_dict())

8、权重参数初始化-----torch.nn.init函数

# 参数初始化主要使用这个函数,下面例子简单介绍
import torch
y = torch.randn((3,4))
print(y)
torch.nn.init.constant_(y,3.)
print(y)

tensor([[ 0.6474, -0.2313, -1.5161, -0.9440],
        [-0.5582, -0.6491, -0.6196, -0.3878],
        [ 0.8618, -0.8447, -0.9571,  0.0944]])
tensor([[3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 3., 3., 3.]])
# 实操
# 简单来说,就是将参数中的tensor提取出来,之后调用torch.nn.init函数
# 注意,可选的初始化类型有好多
import torch


dir = r"rnet34.pth"
model = torch.load(dir,map_location="cpu")
model_dict = model.state_dict()


# print(model_dict.keys())  # 打印层结构名称,方便后续筛选
for k,v in model_dict.items():
    if "conv" in k:
        torch.nn.init.kaiming_normal_(v.data,mode="fan_in")
    else:
        continue

9、net.parmeters()

可以看前向传播部分代码,就知道网络参数包括哪几部分,另外,可以将参数是否需要更新加进去

# 输出模型的参数,可以类比第五部分。类型为生成器,进行循环并且将值放进一个列表中,可以作为优化器的参数
from model import resnet34


net = resnet34(num_classes=5)
for parmeters in net.parameters():
    print(parmeters)
    parmeters.requires_grad = True
    # 通过这个布尔赋值冻结部分参数不被更新
    break
# 打印部分输出结果
Parameter containing:
tensor([[[[ 4.1175e-02,  9.4361e-03,  6.4566e-03,  ...,  1.6439e-03,
           -1.8168e-02, -4.9073e-02],
          [-1.0357e-02, -1.4535e-02, -2.4603e-02,  ...,  2.5078e-02,
           -5.8141e-03,  1.2007e-02],
          [-3.3603e-02, -2.7614e-03, -5.5701e-03,  ..., -3.7980e-02,
            4.4938e-02,  6.4760e-03],
          ...,
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值