基于Resnet50的pytorch框架下的图像特征提取

在Resnet50CNN结构下实现图像的特征提取,这里采用的是CV2的图像读入方式,最后再把得到的图像转换成npy格式进行输出得得,图像对应的特征。

# -*- coding: utf-8 -*-
"""
Function: 图像特征的提取,可以依据需求修改CNN的输出,得到不同层网络的输出图像特征
Writer: Zenght
date:2019.2.16

"""
from __future__ import print_function, division, absolute_import
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision import datasets, models, transforms
import os
import cv2
import time
import copy
import torch.utils.data as data
from Rsenet50 import Resnet


class Net(nn.Module):
    #         此处可以添加自行设定的网络结构
    def __init__(self):
        super(Net, self).__init__()



def cv2_imageloader(path):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    img = cv2.imread(path)
    img = cv2.resize(img, (224, 224))
    im_arr = np.float32(img)
    im_arr = np.ascontiguousarray(im_arr[..., ::-1])
    im_arr = im_arr.transpose(2, 0, 1)# Convert Img from BGR to RGB


    for channel, _ in enumerate(im_arr):
        # Normalization
        im_arr[channel] /= 255
        im_arr[channel] -= mean[channel]
        im_arr[channel] /= std[channel]

    # Convert to float tensor
    im_as_ten = torch.from_numpy(im_arr).float()
    # Convert to Pytorch variable
    im_as_var = Variable(im_as_ten, requires_grad=True)

    return im_as_var


def default_loader(path):

    return cv2_imageloader(path)

class CustomImageLoader(data.Dataset):
    ##自定义类型数据输入
    def __init__(self, img_path, txt_path, dataset = '', loader = default_loader, save_path='/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/Feature/MIT67'):
        im_list = []
        im_dirs = []
        im_labels = []
        with open(txt_path, 'r') as files:
            for line in files:
                items = line.split()

                if items[0][0] == '/':
                    imname = line.split()[0][1:]
                    fnewname = '_'.join(imname[:-4].split('/')) + '.npy'
                else:
                    imname = line.split()[0]
                    fnewname = '_'.join(imname[:-4].split('/'))+'.npy'
                im_list.append(os.path.join(img_path, imname))
                im_labels.append(int(items[1]))
                im_dirs.append(os.path.join(save_path, fnewname))
        self.imgs = im_list
        self.labels = im_labels
        self.save_dir = im_dirs
        self.loader = loader
        self.dataset = dataset

    def __len__(self):

        return len(self.imgs)

    def __getitem__(self, item):
        # print(item)
        img_name = self.imgs[item]
        label = self.labels[item]
        imdir = self.save_dir[item]
        img = self.loader(img_name)

        return img, label, imdir


batch_size = 64
device = torch.device('cuda:0')

# SUN397 INPUT
# image_dir = '/media/haitaizeng/000222840009D764/Images'#
# image_datasets = {x : CustomImageLoader(image_dir, txt_path=('/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/Trainfiles/SUN397/'+x+'Images.label'),
#
#                                         dataset=x) for x in ['Train', 'Test']
#                   }

#MIT67 INPUT
image_dir = '/media/haitaizeng/00038FCE000387A5/cgw/Datasets/MIT67/Images'
image_datasets = {x : CustomImageLoader(image_dir, txt_path=('/home/haitaizeng/stanforf/alex_mit/data_image/'+x+'Images.label'),
                                        dataset=x) for x in ['Train', 'Test']
                  }

dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 batch_size=batch_size,
                                                 shuffle=True) for x in ['Train', 'Test']}


dataset_sizes = {x: len(image_datasets[x]) for x in ['Train', 'Test']}

def Feature_extractor(models, savepath):
    for phase in ['Train', 'Test']:
        for images, labels,save_dir in dataloders[phase]:
            images.to(device)
            labels.to(device)
            # 输出特征,并转换为NPY格式进行保存
            output3 = models(images.cuda())
            output = nn.functional.softmax(output3, dim=0)
            print(output.shape)
            output = output.cpu()
            output = torch.squeeze(output)
            output = output.data.numpy()

            for feat, featpath in zip(output, save_dir):
                np.save(featpath, feat)



if __name__ == '__main__':
    Num_class = 67
    pthpath = '/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/save_model/MIT67/Places0.8500.pth'
    # model_ft = net()  ##这是自行编写的Resnet50,用于后面的特征提取的操作
    model_ft = Resnet([3, 4, 6, 3], Num_class)
    ckpt = torch.load(pthpath, map_location=lambda storage, loc: storage)

    model_ft.load_state_dict(ckpt)
    model_ft.eval()
    model_ft = model_ft.to(device)
    model_ft.cuda()
    path = '/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/Feature/MIT67'
    Feature_extractor(models=model_ft, savepath=path)

 

  • 13
    点赞
  • 110
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值