MobileNet V2基于pytorch的多标签训练分类

        在分类的过程中有时会涉及一些多标签的分类情况,下面将从数据准备、数据载入、如何训练以及如何测试等四个方面进行说明。

1.什么是多标签分类

        我们常见的二分类或者多分类基本都是基于单标签的,即一张图片只能是一个标签。多标签的意思是,比如下面的图既可以是书包、也可以是人或者行李箱等多个标签。如果单标签分类就只能是这三个中的一个,而多标签就不会出现这个情况。

2.数据集准备

         现在假设我们要构建一个书包、行李箱加上负样本的三分类的多标签分类器,我们设置三个单独的标签分别为1,0,0;0,1,0和0,0,1,分别对应的标签为负样本、书包和行李箱。特别的,如果一个人既背了书包又拉了行李箱,则其标签应该为0,1,1,其他类似。

          如果我们有负样本、背书包、拉行李箱和既拉行李箱又背书包四种情况,先建立Train文件夹,然后在里面分别创建四个文件夹,对应的文件夹命名为:Normal、Backpack、Suitcase以及Backpack_Suitcase。然后将下面代码放在Train文件夹中运行。运行完成之后将生成的filelist.csv里面的内容复制到一个新的命名为Train.txt的文件中,后面要用。

#coding:utf8
import os;
import csv
# Image extensions of interest.
# NOTE(review): this list is not referenced by the rest of the visible script,
# which passes `videoProperty` to GetImgNameByEveryDir instead.
your_need_process_type=[".jpg",".png"]

def GetImgNameByEveryDir(file_dir,videoProperty):
    """Recursively collect the image files found under ``file_dir``.

    Args:
        file_dir: Root directory that is walked with ``os.walk``.
        videoProperty: Collection of accepted file extensions, each including
            the leading dot (e.g. ``['.jpg', '.bmp']``). The comparison is an
            exact, case-sensitive match against ``os.path.splitext``.

    Returns:
        A tuple ``(FileName, FileNameWithPath, FileDir)`` of three parallel
        lists: the bare file name, the joined path, and the directory that
        contains the file (with a trailing path separator).
    """
    FileNameWithPath = []
    FileName         = []
    FileDir          = []
    for root, dirs, files in os.walk(file_dir):
        for file in files:
            if os.path.splitext(file)[1] in videoProperty:
                FileNameWithPath.append(os.path.join(root, file))  # full path of the image
                FileName.append(file)                              # bare image name
                # Joining with '' appends the separator to `root`. The previous
                # os.path.join(root, '/') discarded `root` entirely (an absolute
                # second component resets the join) and always produced '/'.
                FileDir.append(os.path.join(root, ''))             # containing folder
    return FileName,FileNameWithPath,FileDir

file_dir = './DataDir/' # folder that holds the raw, per-class image folders
videoProperty = ['.jpg','.jpeg','.bmp']
FileName,FileNameWithPath,FileDir = GetImgNameByEveryDir(file_dir,videoProperty)
csvName = 'filelist' + '.csv'

TrainDir = './Train/'   # folder the images will finally live in at training time

DataDir = ['Normal','Backpack','Suitcase','Backpack_Suitcase']
# Multi-hot label per class folder; the columns are [negative, backpack, suitcase].
# Images of each of the four cases must sit in their own folder.
DirLabels = {
    'Normal':            [1, 0, 0],
    'Backpack':          [0, 1, 0],   # was written as 1,0,0 (same as Normal)
    'Suitcase':          [0, 0, 1],
    'Backpack_Suitcase': [0, 1, 1],
}
with open(csvName,"w") as datacsv:
    csvwriter = csv.writer(datacsv,dialect=("excel"))
    for name, path in zip(FileName, FileNameWithPath):
        # Match on whole path components rather than substrings: otherwise a
        # file in 'Backpack_Suitcase' would also match 'Backpack' and
        # 'Suitcase' and be written three times with conflicting labels.
        folders = os.path.normpath(os.path.dirname(path)).split(os.sep)
        for class_dir in DataDir:
            if class_dir in folders:
                csvwriter.writerow([TrainDir + name] + DirLabels[class_dir])
                break

3.如何训练

         请首先建立MobileNet v2文件夹。建立Train的文件夹,将上面所说的数据全部放入其中,并将Train.txt文件也放在当前根目录(验证集列表Test.txt同理,ReadData.py会同时读取这两个文件)。

3.1读取数据模块

    文件名为:ReadData.py。代码如下:

from __future__ import print_function
import argparse 
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from PIL import Image
import time
import matplotlib.pyplot as plt
import os
import cv2





def GetImgNameByEveryDir(file_dir):
    """Recursively collect the .png/.jpg/.bmp files found under ``file_dir``.

    Args:
        file_dir: Root directory that is walked with ``os.walk``.

    Returns:
        A tuple ``(FileName, FileNameWithPath, FileDir)`` of three parallel
        lists: the bare file name, the joined path, and the containing
        directory with the leading ``file_dir`` prefix sliced off.
    """
    FileNameWithPath = []
    FileName         = []
    FileDir          = []
    # os.path.splitext returns the extension WITH its leading dot, so every
    # entry needs one; the former 'jpg' / 'bmp' entries could never match.
    videoProperty=['.png','.jpg','.bmp']
    for root, dirs, files in os.walk(file_dir):
        for file in files:
            if os.path.splitext(file)[1] in videoProperty:
                FileNameWithPath.append(os.path.join(root, file))  # full path of the image
                FileName.append(file)                              # bare image name
                FileDir.append(root[len(file_dir):])               # folder relative to file_dir
    return FileName,FileNameWithPath,FileDir



def default_loader(path):
    """Open the image stored at ``path`` and return it converted to RGB mode."""
    return Image.open(path).convert('RGB')

class MyDataset(Dataset):
    """Multi-label image dataset described by a whitespace-separated list file.

    Every usable line of the list file has the form
    ``<image path> <l0> <l1> <l2>``: the last three tokens are the integer
    entries of the multi-hot label, and everything before them is the image
    path. A path may itself contain spaces (e.g. copies named ``xxx - 副本
    .jpg``); its tokens are re-joined with single spaces.

    Args:
        txt: Path of the list file.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the label array.
        loader: Callable that turns an image path into an image object.
    """

    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        imgs = []
        skipped = 0
        # `with` guarantees the list file is closed even if parsing fails.
        with open(txt, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:
                    continue  # blank line
                if len(words) < 4:
                    # Previously such a line silently reused the path/label of
                    # the preceding line (or raised NameError on the first one).
                    skipped += 1
                    continue
                try:
                    label_ = np.array([int(w) for w in words[-3:]])
                except ValueError:
                    skipped += 1
                    continue
                imgname = ' '.join(words[:-3])
                # Entries whose image file is missing are dropped, as before.
                if os.path.exists(imgname):
                    imgs.append((imgname, label_))
        if skipped:
            print('skipped malformed lines = ', skipped)
        print('imgnum = ', len(imgs))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        # target_transform used to be stored but never applied.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img,label

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.imgs)

def GetData(batch_size,isTrain,num_workers,rootDir,augment=False):
    """Build the train/val datasets and their data loaders.

    Args:
        batch_size: Batch size used by every loader.
        isTrain: When truthy, loaders for both 'train' and 'val' are returned;
            otherwise only a 'val' loader is built.
        num_workers: Number of worker processes per loader.
        rootDir: Directory that contains ``Train.txt`` and ``Test.txt``.
        augment: When True the training split additionally gets random flip,
            rotation, colour jitter and grayscale augmentation. Defaults to
            False, which matches the previous behaviour (the augmentation
            pipeline existed but was never used).

    Returns:
        ``(image_datasets, dataset_sizes, dataloders)``; the first two are
        dicts keyed by 'train' and 'val', the last holds the loaders.
    """
    img_size = 224
    mean = [ 0.485, 0.456, 0.406 ]
    std  = [ 0.229, 0.224, 0.225 ]

    plain_transform = transforms.Compose([
        transforms.Resize((img_size,img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean = mean, std = std),
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.Resize((img_size,img_size)),
            transforms.RandomHorizontalFlip(),                        # random horizontal flip
            transforms.RandomRotation(10),                            # random rotation within +/-10 degrees
            transforms.ColorJitter(brightness=1,contrast=1,hue=0.5),  # brightness factor drawn from 0 ~ 2, 1 = original
            transforms.RandomGrayscale(p=0.5),                        # grayscale with probability 0.5
            transforms.ToTensor(),
            transforms.Normalize(mean = mean, std = std),
        ])
    else:
        train_transform = plain_transform

    # os.path.join works with or without a trailing separator on rootDir;
    # plain string concatenation turned '/ImageNet' into '/ImageNetTrain.txt'.
    image_datasets = {
        'train': MyDataset(txt=os.path.join(rootDir, 'Train.txt'), transform=train_transform),
        'val':   MyDataset(txt=os.path.join(rootDir, 'Test.txt'), transform=plain_transform),
    }

    phases = ['train', 'val'] if isTrain else ['val']
    dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 batch_size=batch_size,
                                                 # only the training split benefits from shuffling
                                                 shuffle=bool(isTrain) and x == 'train',
                                                 num_workers=num_workers) for x in phases}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

    print('dataset_sizes = ',dataset_sizes)
    return image_datasets, dataset_sizes,dataloders
        
# if __name__ == '__main__':
#     batch_size   = 64
#     IsTrain      = 1
#     num_workers  = 0
#     GetData(batch_size,IsTrain,num_workers)

3.2 多标签分类网络

   文件名为:mobilenet_v2.py。代码如下:

import torch.nn as nn
import math
import torch
from torch import nn
from torch.nn import init
from torchvision import models

def conv_bn(inp, oup, stride):
    """Return a 3x3 convolution (padding 1, no bias) + BatchNorm2d + ReLU6 stack.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the convolution.
    """
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, norm, act)


def conv_1x1_bn(inp, oup):
    """Return a 1x1 convolution (stride 1, no bias) + BatchNorm2d + ReLU6 stack.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(pointwise, norm, act)


class InvertedResidual(nn.Module):
    """Inverted residual block: optional 1x1 expansion, 3x3 depthwise, 1x1 linear projection.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: Channel expansion factor. When it equals 1 the
            expansion stage is omitted.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        # The skip connection is only valid when input and output shapes agree.
        self.use_res_connect = self.stride == 1 and inp == oup

        stages = []
        if expand_ratio != 1:
            # pw: expand to hidden_dim channels
            stages.append(nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False))
            stages.append(nn.BatchNorm2d(hidden_dim))
            stages.append(nn.ReLU6(inplace=True))
        # dw: one 3x3 filter per channel
        stages.append(nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False))
        stages.append(nn.BatchNorm2d(hidden_dim))
        stages.append(nn.ReLU6(inplace=True))
        # pw-linear: project to oup channels, no activation afterwards
        stages.append(nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False))
        stages.append(nn.BatchNorm2d(oup))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, adding the input back when the skip connection is enabled."""
        out = self.conv(x)
        if self.use_res_connect:
            return x + out
        return out


def weights_init_kaiming(m):
    """Initialise module ``m`` in place according to its class name.

    Classes whose name contains 'Conv' get Kaiming-normal weights (fan_in);
    'Linear' gets Kaiming-normal weights (fan_out) and a zero bias;
    'BatchNorm1d' gets weights drawn from N(1.0, 0.02) and a zero bias.
    Any other module is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        return
    if 'Linear' in kind:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
        init.constant_(m.bias.data, 0.0)
        return
    if 'BatchNorm1d' in kind:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)

class ClassBlock(nn.Module):
    """Single-output classification head ending in a sigmoid.

    Layers, in order: Linear(input_dim -> num_bottleneck), BatchNorm1d,
    LeakyReLU(0.1), Dropout(0.5), Linear(num_bottleneck -> 1), Sigmoid.
    The stack is initialised with ``weights_init_kaiming``.

    Args:
        input_dim: Size of the incoming feature vector.
        num_bottleneck: Width of the hidden layer.
    """

    def __init__(self, input_dim, num_bottleneck = 512):
        super(ClassBlock, self).__init__()
        head = nn.Sequential(
            nn.Linear(input_dim, num_bottleneck),
            nn.BatchNorm1d(num_bottleneck),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5),
            nn.Linear(num_bottleneck, 1),
            nn.Sigmoid(),
        )
        head.apply(weights_init_kaiming)
        self.classifier = head

    def forward(self, x):
        """Return the sigmoid output of the head for features ``x``."""
        return self.classifier(x)

# n_class = 3
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone with one independent sigmoid head per label.

    Args:
        n_class: Number of labels; one ``ClassBlock`` head named
            ``class_<i>`` is created for each.
        input_size: Side length of the square input image; must be a
            multiple of 32 (the total stride of the backbone).
        width_mult: Channel width multiplier.

    The forward pass returns a ``(batch, n_class)`` tensor obtained by
    concatenating the single-column outputs of the heads.
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        self.n_class = n_class
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # only the first block of each group applies the group's stride
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
        # The backbone downsamples by 32, so the final feature map has side
        # input_size // 32 (7 for 224, 4 for 128). Deriving the kernel from
        # input_size replaces the hard-coded 7 that only suited 224 inputs.
        self.avgpool  = nn.AvgPool2d(kernel_size=input_size // 32, stride=1)

        for c in range(n_class):
            setattr(self, 'class_%d' % c, ClassBlock(self.last_channel, 512))

    def forward(self, x):
        """Return the per-label sigmoid outputs, shape ``(batch, n_class)``."""
        x = self.features(x)
        # An explicit pooling layer is used instead of x.mean(3).mean(2)
        # because, per the original author's note, ncnn does not support mean.
        x = self.avgpool(x)
        # Flatten with the actual channel count; the former literal 1280 was
        # wrong whenever width_mult > 1 enlarges last_channel.
        x = x.view(-1, self.last_channel)
        heads = [getattr(self, 'class_%d' % c)(x) for c in range(self.n_class)]
        return torch.cat(heads, dim=1)

3.3 训练代码

    文件名为:train.py。代码如下:

from __future__ import print_function, division

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import time
import os
import argparse
import sys
from torchvision import models
from mobilenet_v2 import MobileNetV2
from ReadData import GetData
# python train.py --train_data_dir /media/ubuntu_data/head_60_train_val/train --val_data_dir /media/ubuntu_data/head_60_train_val/val --resume ./mobilenet_v2_pretrain/mobilenet_v2.pth.tar --save_path ./weights_60/
import numpy as np
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from torchviz import make_dot


# Module-level history of per-batch loss values; train_model appends to it
# and plots it into the loss figure.
loss_for_show = []

def to_one_hot(x, C=2, tensor_class=torch.FloatTensor):
    """One-hot a batched index tensor of shape (B, ...) into (B, C, ...).

    Args:
        x: Integer (int64) tensor of class indices, each in ``[0, C)``.
        C: Number of classes, i.e. the size of the inserted dimension.
        tensor_class: Legacy tensor type used to allocate the result.

    Returns:
        A tensor of type ``tensor_class`` with a 1 at each indexed class and
        0 elsewhere, placed on the same device as ``x``.
    """
    x_one_hot = tensor_class(x.size(0), C, *x.shape[1:]).zero_()
    # scatter_ requires the index and the destination on the same device; the
    # buffer above is allocated on the CPU, so follow x (no-op for CPU input).
    x_one_hot = x_one_hot.to(x.device)
    return x_one_hot.scatter_(1, x.unsqueeze(1), 1)



def train_model(args, model, criterion, optimizer, scheduler, num_epochs,dataloders, dataset_sizes):
    """Run the train/validation loop and save checkpoints.

    Args:
        args: Parsed command-line namespace; ``save_path``, ``start_epoch``
            and ``save_epoch_freq`` are read here.
        model: Network to optimise (optionally wrapped in DataParallel).
        criterion: Loss applied to ``(outputs, float labels)``.
        optimizer: Optimiser stepping the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run after ``args.start_epoch``.
        dataloders: Dict with 'train' and 'val' loaders yielding
            ``(images, labels)`` batches.
        dataset_sizes: Dict of dataset sizes. Kept for interface
            compatibility; averages use the number of samples actually seen.

    Returns:
        The trained model.
    """
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Follow the device the model already lives on instead of relying on a
    # module-level `use_gpu` flag defined outside this function.
    device = next(model.parameters()).device
    graph_rendered = False
    last_epoch = args.start_epoch + num_epochs

    for epoch in range(args.start_epoch+1, last_epoch+1):
        print('Epoch {}/{}'.format(epoch, last_epoch))
        print('-'*50)
        # Each epoch has a training and a validation phase.
        for phase in ['train','val']:
            is_train = phase == 'train'
            model.train(is_train)

            running_loss = 0.0
            seen = 0
            num_batches = len(dataloders[phase])
            for i, (images, labels) in enumerate(dataloders[phase]):
                images = images.to(device)
                # The loss needs float targets on every device; previously the
                # cast only happened on the GPU path.
                labels = labels.float().to(device)
                optimizer.zero_grad()

                # No autograd bookkeeping is needed during validation.
                with torch.set_grad_enabled(is_train):
                    outputs = model(images)
                    label_loss = criterion(outputs, labels)
                    if is_train:
                        if not graph_rendered:
                            # Model visualisation; once is enough, rendering
                            # on every batch only slowed training down.
                            g = make_dot(outputs)
                            g.render('espnet_model', view=False)
                            graph_rendered = True
                        label_loss.backward()
                        optimizer.step()

                batch_loss = label_loss.item()
                if is_train:
                    # Plain floats: np.array() cannot convert CUDA tensors.
                    loss_for_show.append(batch_loss)
                # Weight by batch size so the epoch figure is a per-sample
                # mean (assumes the criterion averages over the batch, which
                # is the default reduction -- TODO confirm for custom losses).
                running_loss += batch_loss * images.size(0)
                seen += images.size(0)
                print('[Epoch {}/{}]-[batch:{}/{}] lr:{:.7f} {} Loss: {:.6f} '.format(epoch, last_epoch, i, num_batches-1, optimizer.param_groups[0]['lr'], phase, batch_loss))

            epoch_loss = running_loss / max(seen, 1)
            print('phase: {} epoch: {} Loss: {:.4f}'.format(phase, epoch , epoch_loss))

            if is_train:
                # Exactly one scheduler step per epoch, after the optimiser
                # has stepped (it used to be called twice, before training).
                scheduler.step()

        # Refresh the loss curve once per epoch rather than once per batch.
        plt.figure(1)
        plt.clf()
        plt.plot(np.array(loss_for_show))
        plt.savefig(os.path.join(args.save_path, 'loss.png'))

        # Epochs are 1-based here, so test `epoch` itself against the frequency.
        if epoch % args.save_epoch_freq == 0:
            if not os.path.exists(args.save_path):
                os.makedirs(args.save_path)
            # The whole model object is pickled (not just a state_dict).
            torch.save(model, os.path.join(args.save_path, "epoch_" + str(epoch) + ".pth.tar"))

    return model


if __name__ == '__main__':
    # Train from scratch:
    #   python train.py --batch_size 256 --gpus 0,1,2,3
    # Continue from a checkpoint (--start_epoch is the epoch of that checkpoint):
    #   python train.py --batch_size 256 --gpus 0,1,2,3 --resume output/epoch_4.pth.tar --start_epoch 4
    print(torch.__version__)
    parser = argparse.ArgumentParser(description="PyTorch multi-label training with MobileNetV2")
    parser.add_argument('--train_data_dir', type=str, default="/ImageNet")
    parser.add_argument('--val_data_dir', type=str, default="/ImageNet")
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_class', type=int, default=3)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--num_workers', type=int, default=0)
    parser.add_argument('--gpus', type=str, default='0')
    parser.add_argument('--print_freq', type=int, default=50)
    parser.add_argument('--save_epoch_freq', type=int, default=1)
    parser.add_argument('--save_path', type=str, default="three_class_model")
    parser.add_argument('--resume', type=str, default="", help="For training from one checkpoint")
    parser.add_argument('--start_epoch', type=int, default=0, help="Corresponding to the epoch of resume ")
    parser.add_argument('--load_path', type=str, default="", help="For training from one model_file")
    args = parser.parse_args()

    # Build the model and optionally initialise it from a checkpoint.
    model=MobileNetV2(n_class=args.num_class)
    if args.resume:
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            # NOTE: torch.load unpickles arbitrary objects; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(args.resume, map_location='cpu')
            print("the start epoch is {}".format(args.start_epoch))
            if isinstance(checkpoint, nn.Module):
                # train_model saves the whole (possibly DataParallel) model.
                checkpoint = getattr(checkpoint, 'module', checkpoint).state_dict()
            elif 'model_state_dict' in checkpoint:
                checkpoint = checkpoint['model_state_dict']
            model_dict = model.state_dict()
            state_dict = {}
            for k, v in checkpoint.items():
                if k.startswith('module.'):
                    k = k[len('module.'):]  # saved from a DataParallel wrapper
                # Keep only tensors this model has, with a matching shape, so
                # an ImageNet-pretrained backbone loads without its classifier.
                if k in model_dict and v.shape == model_dict[k].shape:
                    state_dict[k] = v
            print("=> loaded {} of {} tensors".format(len(state_dict), len(model_dict)))
            model_dict.update(state_dict)
            model.load_state_dict(model_dict)
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    image_datasets, dataset_sizes,dataloders= GetData(args.batch_size,True,args.num_workers,args.train_data_dir)

    # use gpu or not
    use_gpu = torch.cuda.is_available()
    print("use_gpu:{}".format(use_gpu))

    if use_gpu:
        model = model.cuda()
        model = torch.nn.DataParallel(model, device_ids=[int(i) for i in args.gpus.strip().split(',')])

    # Every head ends in a sigmoid, so binary cross-entropy on probabilities.
    criterion = nn.BCELoss()
    # --lr is honoured now; its default equals the former hard-coded 0.001.
    optimizer_ft = torch.optim.SGD(model.parameters(), lr = args.lr, momentum = 0.9,weight_decay = 5e-4, nesterov = True,)

    # Multiply the learning rate by 0.9 every 5 scheduler steps. This is the
    # schedule that was actually in effect: an earlier StepLR(20, 0.1) was
    # created and immediately overwritten.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.9)

    model = train_model(args=args,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer_ft,
                           scheduler=exp_lr_scheduler,
                           num_epochs=args.num_epochs,
                           dataloders=dataloders,
                           dataset_sizes=dataset_sizes)

4. 测试代码

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import time
import os
import argparse
import shutil
from read_SmileData import SmileData as ImageNetData
import ShuffleNetV2
from mobilenet_v2 import MobileNetV2
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt # plt 用于显示图片
import matplotlib.image as mpimg # mpimg 用于读取图片
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import cv2
import torch.nn.functional as F
import numpy as np
import traceback
import torchvision
import sys

# Side length (pixels) images are resized to before inference.
img_size=224
# Display names indexed by model output position.
# NOTE(review): the dataset-preparation script names the classes
# Normal/Backpack/Suitcase -- confirm these names refer to the same order.
labels = ['Normal','Package','Luggage']
# Preprocessing pipelines keyed by phase; only 'val' is read in this script
# (by img_to_tensor).
data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((img_size,img_size),3),  # 3 is presumably PIL's bicubic filter code -- TODO confirm
            #transforms.RandomCrop(img_size),
            transforms.RandomHorizontalFlip(),# random horizontal flip
            transforms.RandomRotation(10),    # random rotation within +/-10 degrees (the old comment said 45)
            transforms.ColorJitter(brightness=1,contrast=1,hue=0.5),# brightness factor varies randomly within 0 ~ 2; 1 is the original image
            transforms.RandomGrayscale(p=0.5),    # convert to grayscale with probability 0.5
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        # NOTE(review): resizing is commented out here, so 'val' keeps the
        # image at its original size.
        'val': transforms.Compose([
            # transforms.Resize((img_size,img_size),3),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
 }


def GetImgNameByEveryDir(file_dir,videoProperty):
    """Walk ``file_dir`` and list every file whose extension is in ``videoProperty``.

    Args:
        file_dir: Root directory to walk.
        videoProperty: Accepted extensions as returned by
            ``os.path.splitext`` (i.e. including the leading dot).

    Returns:
        Three parallel lists: file names, joined paths, and each file's
        folder with the leading ``file_dir`` prefix sliced off.
    """
    prefix_len = len(file_dir)
    matches = [
        (fname, os.path.join(folder, fname), folder[prefix_len:])
        for folder, _subdirs, fnames in os.walk(file_dir)
        for fname in fnames
        if os.path.splitext(fname)[1] in videoProperty
    ]
    names = [entry[0] for entry in matches]
    paths = [entry[1] for entry in matches]
    rel_dirs = [entry[2] for entry in matches]
    return names, paths, rel_dirs

def TestImgDir(args,model,use_gpu):
    """Run the model on every .jpg under ``args.data_dir`` and show the result.

    For each image the per-label probabilities are printed and the names of
    all labels whose probability exceeds 0.5 are drawn onto the image, which
    is then displayed until a key is pressed.

    Args:
        args: Parsed command-line namespace; ``data_dir`` is read here.
        model: Network returning one sigmoid probability per label.
        use_gpu: When truthy, inputs are moved to the GPU before inference.
    """
    FileName,FileNameWithPath,FileDir = GetImgNameByEveryDir(args.data_dir,['.jpg'])
    # Same normalisation constants as the 'val' entry of data_transforms.
    mean = [0.485, 0.456, 0.406]
    std  = [0.229, 0.224, 0.225]
    font = cv2.FONT_HERSHEY_SIMPLEX

    with torch.no_grad():
        for name, path in zip(FileName, FileNameWithPath):
            image = cv2.imread(path)
            if image is None:
                # cv2.imread returns None instead of raising on failure.
                print('skip unreadable image: ', path)
                continue
            img    = Image.open(path).convert('RGB')
            img    = torchvision.transforms.functional.resize(img, (img_size, img_size))
            inputs = torchvision.transforms.functional.to_tensor(img)
            # The raw [0, 1] tensor used to be fed to the network without the
            # mean/std normalisation.
            inputs = torchvision.transforms.functional.normalize(inputs, mean, std)
            inputs = inputs.unsqueeze(0)
            if use_gpu:
                inputs = inputs.cuda()

            outputs = model(inputs)
            # Each output is already an independent sigmoid probability, so
            # threshold every label separately. A softmax + argmax would force
            # exactly one label and defeat the multi-label design.
            preds  = outputs[0].tolist()
            scored = list(zip(labels, preds))
            predicted = [label for label, prob in scored if prob > 0.5]
            if not predicted and scored:
                # Nothing passed the threshold: report the most likely label.
                predicted = [max(scored, key=lambda pair: pair[1])[0]]
            text = '+'.join(predicted)
            print('FileName = ',name,' preds = ',text,' ',preds)

            cv2.putText(image,text,(0,int(image.shape[0]/2)),font,1.2,(0,255,255),1)
            cv2.imshow(' ',image)
            cv2.waitKey(0)



def img_to_tensor(img_full_path):
    """Load an image file and return it preprocessed by ``data_transforms['val']``.

    Args:
        img_full_path: Path of the image file.

    Returns:
        The output of the 'val' transform pipeline for that image.
    """
    with Image.open(img_full_path) as img:
        # The 'val' pipeline normalises with three channel statistics, so
        # grayscale or RGBA files must be converted to RGB first.
        data = img.convert('RGB')
    return data_transforms['val'](data)

def ModelInitial(args):
    """Build a MobileNetV2 and load weights from ``args.load_path``.

    Accepted checkpoint layouts: a pickled model object (optionally wrapped
    in DataParallel, which is what the training script saves), a dict with a
    'model_state_dict' entry, or a plain state dict. A leading 'module.' in
    parameter names is stripped.

    Args:
        args: Parsed command-line namespace; ``load_path``, ``num_class`` and
            ``gpus`` are read here.

    Returns:
        ``(model, use_gpu)`` with the model in eval mode. The process exits
        with status 1 when ``args.load_path`` does not exist.
    """
    # whether a GPU is available
    use_gpu = torch.cuda.is_available()
    print("use_gpu:{}".format(use_gpu))

    if not os.path.exists(args.load_path):
        print("the path:{} is not a file".format(args.load_path))
        sys.exit(1)

    model=MobileNetV2(n_class=args.num_class)

    # NOTE: torch.load unpickles arbitrary objects; only load checkpoints
    # from a trusted source. map_location lets GPU-saved files open on a
    # CPU-only machine.
    try:
        # Recent PyTorch defaults to weights_only=True, which rejects a
        # pickled model object.
        checkpoint = torch.load(args.load_path, map_location='cpu', weights_only=False)
    except TypeError:
        # Older PyTorch without the weights_only keyword.
        checkpoint = torch.load(args.load_path, map_location='cpu')

    if isinstance(checkpoint, nn.Module):
        # Unwrap DataParallel when present instead of assuming `.module` exists.
        checkpoint = getattr(checkpoint, 'module', checkpoint).state_dict()
    elif 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    pretained_dict = {}
    for k, v in checkpoint.items():
        if k.startswith('module.'):
            k = k[len('module.'):]
        pretained_dict[k] = v
    model.load_state_dict(pretained_dict)
    print("load cls model successfully")

    model.eval()
    if use_gpu:
        model = model.cuda()
        model = torch.nn.DataParallel(model, device_ids=[int(i) for i in args.gpus.strip().split(',')])

    return model,use_gpu
    

    


if __name__ == '__main__':
    # Example:
    #   python test.py --data_dir /media/ubuntu_data2/02_dataset/chenqy/test/test_fir/ --load_path /home/lgx/liguanxi/ShuffleNet_V2_pytorch_caffe-master/output_lgx/epoch_39.pth.tar
    parser = argparse.ArgumentParser(description="PyTorch multi-label inference with MobileNetV2")
    parser.add_argument('--data_dir', type=str, default="/ImageNet")
    parser.add_argument('--gpus', type=str, default='0')
    parser.add_argument('--resume', type=str, default="", help="For training from one checkpoint")
    parser.add_argument('--load_path', type=str, default="", help="Path of the trained model file to evaluate")
    # Must equal the number of heads the model was trained with. The training
    # script defaults to 3 and `labels` has three names; the former default
    # of 2 made the weight loading fail.
    parser.add_argument('--num_class', type=int, default=3)
    args = parser.parse_args()
    # load the classification model
    print("start loading cls model")
    model,use_gpu = ModelInitial(args)
    TestImgDir(args,model,use_gpu)

        

注:此代码为基于pytorch的MobileNet V2预训练模型的多标签分类。请先下载相应的ImageNet预训练模型。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值