PyTorch基础4——加载模型权重

加载模型权重

包括
1 加载完全的模型权重
2 加载某一层的模型权重
3 根据tensor形状加载模型权重


from torch import nn
import torch

# 定义一个网络
class Model(nn.Module):
    def __init__(self,class_num,input_channel=3):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=input_channel, out_channels=32, kernel_size=3) #卷积
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # 池化
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5) #卷积
        self.dropout = nn.Dropout2d(p=0.1) # dropout
        self.adaptive_pool = nn.AdaptiveMaxPool2d((1, 1)) #全局池化
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(64, 32) #线性层
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(32, class_num) #最终分了多少个类
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.dropout(x)
        x = self.adaptive_pool(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        y = self.sigmoid(x)
        return y

net = Model(class_num=5,input_channel=1)

# 读取权重文件,权重文件的本质是一个字典
# key是网络每一层的变量名, value是具体的张量值
weight_dict = torch.load("net.pth")
for key,value in weight_dict.items():
    print(key,",",value.size())
# conv1.weight , torch.Size([32, 1, 3, 3])
# conv1.bias , torch.Size([32])
# conv2.weight , torch.Size([64, 32, 5, 5])
# conv2.bias , torch.Size([64])
# linear1.weight , torch.Size([32, 64])
# linear1.bias , torch.Size([32])
# linear2.weight , torch.Size([10, 32])
# linear2.bias , torch.Size([10])
# 方式1,整个网络读取权重字典
net.load_state_dict(weight_dict)
# 注意:网络的整个结构必须要一致才行。包括输入的图片通道数,输出的类别数,以及中间的层
# 方式2,给单独的一层加载权重
# 根据变量的名字进行加载
# 给conv1单独加载权重
# 在权重中包含conv1的有
# # conv1.weight , torch.Size([32, 1, 3, 3])
# # conv1.bias , torch.Size([32])
# 那么只需要只需要这两个,并且把前面的conv1去掉, 得到
# # weight , torch.Size([32, 1, 3, 3])
# # bias , torch.Size([32])
conv1_weight_dict = {}
for key,value in weight_dict.items():
    if "conv1" in key:
        new_key = key.replace("conv1.","") # 去掉前面的conv1
        conv1_weight_dict[new_key] = value
net.conv1.load_state_dict(conv1_weight_dict) #就可以进行加载
# 方式3,根据tensor的形状相同加载权重
# print(net.conv1.state_dict())
orginal_dict = net.state_dict() #当前网络的权重字典。
weight_dict = torch.load("net.pth") #读取的网络权重字典
# 通过形状相同,把orignal_dict对应的tensor 换成 weight_dict的tensor。


for key,value in orginal_dict.items():

    for key2,value2 in weight_dict.items():
        if value2.size() == value.size():
            print("形状相同")
            orginal_dict[key] = weight_dict[key2] # 将orginal换成weight_dict

net.load_state_dict(orginal_dict)










  • 6
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值