使用pytorch预训练模型分类与特征提取

    pytorch(pytorch v0.1 这个是早期版本了)应该是深度学习框架里面比较好使用的了,相比于tensorflow,mxnet。可能在用户上稍微少一点,有的时候出问题不好找文章。下面就使用pytorch预训练模型做分类和特征提取,pytorch文档可以参考:pytorch docs  , 模型是imagenet2012训练的标签可参考:imagenet2012 labels  ,模型预测的下标按从上到下,起始(n01440764)为0

   

#encoding=utf-8

import os
import numpy as np

import torch
import torch.nn
import torchvision.models as models
from torch.autograd import Variable 
import torch.cuda
import torchvision.transforms as transforms

from PIL import Image

transform_list = [transforms.ToTensor(),
                  transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                       std =[0.229, 0.224, 0.225])]
img_to_tensor = transforms.Compose(transform_list)


def make_model():
    resmodel=models.resnet34(pretrained=True)
    resmodel.cuda()#将模型从CPU发送到GPU,如果没有GPU则删除该行
    return resmodel

#分类
def inference(resmodel,imgpath):
    resmodel.eval()#必需,否则预测结果是错误的
    
    img=Image.open(imgpath)
    img=img.resize((224,224))
    tensor=img_to_tensor(img)
    
    tensor=tensor.resize_(1,3,224,224)
    tensor=tensor.cuda()#将数据发送到GPU,数据和模型在同一个设备上运行
            
    result=resmodel(Variable(tensor))
    result_npy=result.data.cpu().numpy()#将结果传到CPU,并转换为numpy格式
    max_index=np.argmax(result_npy[0])
    
    return max_index
    
#特征提取
def extract_feature(resmodel,imgpath):
    resmodel.fc=torch.nn.LeakyReLU(0.1)
    resmodel.eval()
    
    img=Image.open(imgpath)
    img=img.resize((224,224))
    tensor=img_to_tensor(img)
    
    tensor=tensor.resize_(1,3,224,224)
    tensor=tensor.cuda()
            
    result=resmodel(Variable(tensor))
    result_npy=result.data.cpu().numpy()
    
    return result_npy[0]
    
if __name__=="__main__":
    model=make_model()
    imgpath='path_to_img/xxx.jpg'
    print inference(model,imgpath)
    print extract_feature(model, imgpath)
    

 

参考:  pytorch doc

 

 

 

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值