问题描述
在使用 PyTorch 构建包含 ResNet50 的网络时,应如何正确加载 ResNet50 的预训练权重?
例如,在测试一个 TestNet 时,该网络包含一个 1×1 卷积层、一个 ResNet50 和若干线性层。我们需要在初始化 ResNet50 层时加载其预训练权重,可按以下步骤操作,以避免加载权重时出错。
Resnet50架构:
解决方案
1.定义model_urls来指定预训练权重的下载地址
# Download URLs of the pretrained ResNet checkpoints, keyed by architecture name.
model_urls = dict((
    ('resnet18', 'https://download.pytorch.org/models/resnet18-5c106cde.pth'),
    ('resnet34', 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
    ('resnet50', 'https://download.pytorch.org/models/resnet50-19c8e357.pth'),
    ('resnet101', 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
    ('resnet152', 'https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
))
2.初始化resnet50后,通过以下方式加载权重文件
# Build ResNet50 with freshly initialised weights. Passing `pretrained=pretrained`
# to the constructor would already make torchvision download and load a
# checkpoint, so the explicit load below would be a redundant second load.
res50_o = models.resnet50()
if pretrained:
    print('Loading pretrained model')
    # Fetch the checkpoint from model_urls and copy it into the model; this is
    # now the single place where pretrained weights are loaded.
    pretrained_weights = model_zoo.load_url(model_urls['resnet50'])
    res50_o.load_state_dict(pretrained_weights)
    print('Loaded Complete!')
3.TestNet完整代码如下:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.utils.model_zoo as model_zoo
# Where each pretrained ResNet checkpoint can be downloaded from; the keyword
# names become the dictionary keys (architecture names).
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)
class TestNet(nn.Module):
    """Two-class network built around a ResNet50 feature extractor.

    Data flow in ``forward``:

    1. ``fc`` projects the last input dimension from ``input_shape`` to 128.
    2. The result is viewed as ``(batch, channel_shape, 1, 128)``.
    3. A 1x1 convolution maps ``channel_shape`` channels to the 3 channels
       ResNet50 expects.
    4. ResNet50 (final fully connected layer removed) yields 2048 features.
    5. Six linear layers reduce 2048 -> 1024 -> 512 -> 128 -> 64 -> 32 -> 2.

    Args:
        input_shape: Size of the last dimension of the input tensor.
        channel_shape: Number of channels fed to the 1x1 convolution.
        pretrained: If truthy, download the ResNet50 checkpoint listed in
            ``model_urls`` and load it into the backbone.
    """

    def __init__(self, input_shape: int, channel_shape: int, pretrained: bool):
        super().__init__()
        self.channel_shape = channel_shape
        self.fc = nn.Linear(input_shape, 128)

        # 1x1 convolution that adapts the channel count to the 3 channels
        # expected by the first ResNet50 convolution.
        self.conv1x1 = nn.Conv2d(
            in_channels=channel_shape, out_channels=3, kernel_size=1
        )

        # Build ResNet50 with freshly initialised weights. The original passed
        # `pretrained=pretrained` here, which made torchvision load a
        # checkpoint that was then immediately overwritten by the explicit
        # load below (two downloads/loads for one model).
        res50_o = models.resnet50()
        if pretrained:
            print('Loading pretrained model')
            # Single, explicit load of the pretrained checkpoint.
            pretrained_weights = model_zoo.load_url(model_urls['resnet50'])
            res50_o.load_state_dict(pretrained_weights)
            print('Loaded Complete!')

        # Keep every ResNet50 child except the last one (its fully connected
        # classifier), so the backbone ends at the global average pool.
        self.res50 = nn.Sequential(*list(res50_o.children())[:-1])

        # Classification head. ResNet50 produces 2048 features per sample.
        # NOTE(review): these layers are applied back to back with no
        # activation in between, so together they amount to a single affine
        # map. Left as-is to preserve behaviour and checkpoint compatibility.
        self.fc_a = nn.Linear(2048, 1024)
        self.fc_b = nn.Linear(1024, 512)
        self.fc_c = nn.Linear(512, 128)
        self.fc_d = nn.Linear(128, 64)
        self.fc_e = nn.Linear(64, 32)
        self.fc_class = nn.Linear(32, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the 2 class logits per sample.

        Args:
            x: Input whose last dimension is ``input_shape``; presumably
                shaped ``(batch, channel_shape, input_shape)`` -- confirm
                against the caller.

        Returns:
            Tensor of shape ``(batch, 2)``.
        """
        x = self.fc(x)
        # Reshape to (batch, channels, height=1, width); the width is read
        # from the projection layer so the two cannot drift apart.
        x = x.view(x.size(0), self.channel_shape, 1, self.fc.out_features)
        # The learned 1x1 convolution produces the 3 input channels; the
        # original also sketched (commented out) repeating the channels or
        # concatenating the channel-wise mean as alternatives.
        x = self.conv1x1(x)
        x = self.res50(x)
        # Flatten the backbone output to (batch, 2048).
        x = x.view(x.size(0), -1)
        x = self.fc_a(x)
        x = self.fc_b(x)
        x = self.fc_c(x)
        x = self.fc_d(x)
        x = self.fc_e(x)
        # Final classification layer.
        x = self.fc_class(x)
        return x