ResNet50修改网络适应灰度图片并加载预训练模型

此博文是修改https://blog.csdn.net/jiacong_wang/article/details/105631229
这位大大的博文而成的,自己根据自己的情况稍微加了点东西

要修改的地方有4处

1.修改网络第一层,把3通道改为1
法一:直接在定义网络的地方修改

self.conv1 = nn.Conv2d(1, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)

法二:在调用网络模型的地方修改

model = resnet50()
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model = model.to(device)

2.修改读取数据的方式
--------修改之前:
在这里插入图片描述
--------修改之后

 train_transformer = transforms.Compose([
        transforms.Grayscale(1), # 修改
        transforms.RandomHorizontalFlip(0.5),  
        transforms.ToTensor(),                
        transforms.Normalize(0.485, 0.229, inplace=True),  # 修改(-1,1)
    ])
  • 修改方法
  1. 修改transform(图像预处理操作)
      添加transforms.Grayscale(1),将图像转换为单通道图像(经实验,图像矩阵的数据并不会发生变化)
  2. transforms.Normalize修改如下,第一个参数为mean,第二个参数为std,因为是单通道,所以进行Z-Score时仅需要对一个通道进行操作,所以mean和std只需要一个值就行

3.修改读取数据集部分(mydataset.py)
-----修改之前
在这里插入图片描述
-----修改之后

img = Image.open(path_img)

只需要把后面转RGB的部分去掉就行

4.因为加载预训练模型而修改网络
要加载预训练模型,第一层的权重参数肯定不能加载,则只需要把第一层的权重参数避开就行:

  • 加载方法
    net = resnet50()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet50-19c8e357.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

    # 加载预训练模型并且把不需要的层去掉
    pre_state_dict = torch.load(model_weight_path)
    print("原模型", pre_state_dict.keys())
    new_state_dict = {}
    for k, v in net.state_dict().items():          # 遍历修改模型的各个层
        print("新模型", k)
        if k in pre_state_dict.keys() and k!= 'conv1.weight':
            new_state_dict[k] = pre_state_dict[k]  # 如果原模型的层也在新模型的层里面, 那新模型就加载原先训练好的权重
    net.load_state_dict(new_state_dict, False)
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)
    net.to(device)
  • 23
    点赞
  • 92
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
预训练的ResNet-18模型的输入图像通道数应该是3,而不是1。 ResNet-18是一个经过大规模图像分类任务预训练的卷积神经网络模型。它最初在ImageNet数据集上进行了训练,该数据集的图像具有RGB通道(红、绿、蓝),因此ResNet-18模型预期的输入图像应具有3个通道。 如果你想将灰度图像(1个通道)输入ResNet-18模型进行预测,你需要将其转换为具有3个通道的伪RGB图像。可以通过将灰度图像在每个通道上复制相同的值来实现,以创建一个具有3个相同通道的图像。 下面是一个示例代码,演示了如何将灰度图像转换为伪RGB图像: ```python import torch import torch.nn as nn class GrayToRGB(nn.Module): def __init__(self): super(GrayToRGB, self).__init__() def forward(self, x): # 复制灰度图像的通道 x = torch.cat([x, x, x], dim=1) return x # 创建灰度图像 gray_image = torch.randn(1, 1, 224, 224) # 创建ResNet-18模型 resnet = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True) # 创建灰度图像到伪RGB图像的转换层 gray_to_rgb = GrayToRGB() # 将灰度图像转换为伪RGB图像 rgb_image = gray_to_rgb(gray_image) # 将伪RGB图像输入ResNet-18模型进行预测 output = resnet(rgb_image) ``` 在这个示例中,我们首先创建了一个灰度图像`gray_image`,然后加载了预训练的ResNet-18模型。接下来,我们定义了一个名为`GrayToRGB`的自定义层,用于将灰度图像转换为伪RGB图像。最后,我们通过将灰度图像传递给`GrayToRGB`层,得到具有3个通道的伪RGB图像,并将其输入ResNet-18模型进行预测。 需要注意的是,由于预训练的ResNet-18模型是在大规模分类任务上进行训练的,因此用于预测的图像应与训练时的输入图像具有相同的特征表示,即3个通道的RGB图像。将灰度图像转换为伪RGB图像只是一种近似方法,可能会对模型的性能产生一定影响。如果你希望获取更好的性能,可能需要使用其他针对灰度图像的预训练模型或自行训练模型。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值