如何从resnet等建好的模型中删掉一部分,保留想要的

我正在研究resnet,但是不想要最后的全连接层。而是希望输出保持 2,3,224,224,这样,batch,通道,长,宽,的格式。所以找到了这个链接,里面有详细的答案:

https://discuss.pytorch.org/t/how-to-delete-layer-in-pretrained-model/17648/40

下面是我的代码:

import torch
import torch.nn as nn
import torchvision

model = torchvision.models.resnet18()
print(model)  # 打印出来看看。直接用.summary还不行呢。


model = nn.Sequential(*list(model.children())[:-2]) # !!这样可以截取其中的一部分。否则另一种方法是,如链接里的,设置一个Identity的class。用这个代替model中的,比如,layer4,或者fc层(这个名字从print的结果找到的),但是那样无法取代最后的flattern层,肯定会展平,与其再reshape回去,不如直接取子集。


import urllib
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)

from PIL import Image
from torchvision import transforms
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')
model.fc = nn.Sequential()
with torch.no_grad():
    print(input_batch.shape)
    output = model(input_batch)
    print(output.shape)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
# print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
# print(probabilities)

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值