Pytorch加载Resnet50预训练模型权重出现的问题

问题描述

在使用Pytorch加载包含Resnet50预训练模型的权重时应如何加载
例如,在测试一个TestNet时,该网络包含一个一维卷积、一个Resnet50和若干线性层,我们需要在初始化Resnet50层时加载其预训练权重,应该执行以下操作来避免调用权重出错。
Resnet50架构:
Resnet

解决方案

1.定义model_urls来指定预训练权重的下载地址

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

2.初始化resnet50后,通过以下方式加载权重文件

 res50_o = models.resnet50(pretrained=pretrained)
        if pretrained:
            print('Loading pretrained model')
            # Load pretrained weights using load_state_dict
            pretrained_weights = model_zoo.load_url(model_urls['resnet50'])
            res50_o.load_state_dict(pretrained_weights)
            print('Loaded Complete!')

3.TestNet完整代码如下:

import torch
import torch.nn as nn
import torchvision.models as models

import torch.utils.model_zoo as model_zoo

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}

class TestNet(nn.Module):
    def __init__(self, input_shape,channel_shape,pretrained):
        super().__init__()
        self.channel_shape = channel_shape
        self.fc = nn.Linear(input_shape, 128)

        # Define a 1x1 convolution to adjust the number of channels to 3
        self.conv1x1 = nn.Conv2d(in_channels=channel_shape, out_channels=3, kernel_size=1)
        # Initialize ResNet50 model without loading pretrained weights
        res50_o = models.resnet50(pretrained=pretrained)
        if pretrained:
            print('Loading pretrained model')
            # Load pretrained weights using load_state_dict
            pretrained_weights = model_zoo.load_url(model_urls['resnet50'])
            res50_o.load_state_dict(pretrained_weights)
            print('Loaded Complete!')
        # Remove the fully connected layer from ResNet50
        self.res50 = nn.Sequential(*list(res50_o.children())[:-1])

        self.fc_a = nn.Linear(2048, 1024)  # ResNet50 output features are 512
        self.fc_b = nn.Linear(1024, 512)
        self.fc_c = nn.Linear(512, 128)
        self.fc_d = nn.Linear(128, 64)
        self.fc_e = nn.Linear(64, 32)
        self.fc_class = nn.Linear(32, 2)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), self.channel_shape, 1, 128)  # Reshape to (batch_size, channels, height, width)
        # x = x.repeat(1, 3, 1, 1)
        # 计算第二个维度的平均值
        # mean_values = x.mean(dim=1, keepdim=True)  # 结果形状为 (batch, 1, 1, 128)
        # # 将原张量和计算的平均值沿第二个维度拼接
        # x = torch.cat((x, mean_values), dim=1)
        x = self.conv1x1(x)  # Apply 1x1 convolution to get 3 channels
        x = self.res50(x)  # Forward pass through ResNet50
        x = x.view(x.size(0), -1)  # Flatten the output of ResNet50
        x = self.fc_a(x)
        x = self.fc_b(x)
        x = self.fc_c(x)
        x = self.fc_d(x)
        x = self.fc_e(x)
        x = self.fc_class(x) # Final classification layer
        return x
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值