pytorch修改resnet18 输入通道

方法一:扩张1通道为3通道,利用torch.expand()方法

    model = resnet18(pretrained=False) # 主干提取网络
    model.load_state_dict(torch.load('./resnet18-5c106cde.pth'), strict=False)
    print(model)
    par = summary(model, (3, 224, 224), device='cpu')
    print(par)

    net = RFNet( model, 1, use_bn=True) # 输出类别 num_classes
    # print(model)
    input1 = torch.rand((1,3,256,256)) # 输入通道为1
    input1 = input1.expand(1,3,256,256) # 扩展为3通道

    print(input1.shape)
    input2 = torch.rand((1,1,256, 256))
    output = net(input1,input2)
    print(output.shape)

方法二:修改字典参数

import torchvision.models as models
import torch
import torch.nn as nn
from torchsummary import summary

resnet18 = models.resnet18(pretrained=False)
resnet18.conv1= nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,bias=False)

# print(resnet18)
pretrained_dict = torch.load('./resnet/resnet18-5c106cde.pth')
# for k, v in pretrained_dict.items():
#     print(k)

x = torch.rand(64, 1, 7, 7)
pretrained_dict["conv1.weight"] = x

conv1 = pretrained_dict["conv1.weight"]
print(conv1.shape)
resnet18.load_state_dict(pretrained_dict)

# print(resnet18)
par = summary(resnet18, (1, 224, 224),device='cpu')
print(par)

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值