PyTorch 搭建分類模型

1. 神經網絡模型

腳本可以選擇的網絡有(通過 --net-name 指定):mobilenet_v2、googlenet、inception_v3、resnet50、densenet121。使用 inception_v3 時需將 --crop-size 設為 299,因為該網絡要求 299×299 的輸入。當然也可以添加你自己的網絡。

 

2. 模型訓練

a. 數據路徑下包含train和val兩個文件夾,文件夾下面存放所有類別的數據,一個類別一個文件夾。

b. --resize 是圖片縮放尺寸(圖片短邊縮放到該值),--crop-size 是縮放後在圖片中心裁剪出的、輸入網絡的正方形尺寸。

c. 如果要加載訓練過的模型,加上 --pre 參數,並用 --model-path 設置模型路徑。

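d. 如果要使用 Focal loss(--focal-loss),需要在訓練腳本同目錄下新建 FocalLoss.py 並拷貝代碼,代碼我會在後面列出。注意文件名不能寫成 Focal-loss.py(含連字符的文件名無法被 import);另外訓練腳本會無條件執行 import FocalLoss,所以即使不使用 Focal loss,也需要這個文件。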

訓練代碼:

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
import copy
import FocalLoss
import argparse, os
import time, datetime

# import tqdm

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


def _str2bool(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to a bool.

    argparse's ``type=bool`` would turn any non-empty string (including
    "False") into True, so boolean flags that take a value need this parser.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: list of argument strings to parse; ``None`` (the default) means
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with the attributes net_name, data, resize,
        crop_size, batch_size, epochs, classes, save_path, pre_training,
        model_path, focal_loss and feature_extract.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--net-name', dest='net_name', type=str, default='resnet50')
    parser.add_argument('--data', dest='data', type=str, default='./data/class', help='The data path!')
    parser.add_argument('--resize', dest='resize', type=int, default=300, help='img resize')
    parser.add_argument('--crop-size', dest='crop_size', type=int, default=224, help='crop resized img enter net!')
    parser.add_argument('--batch-size', dest='batch_size', type=int, default=16)
    parser.add_argument('--epochs', dest='epochs', type=int, default=100)
    parser.add_argument('--classes', dest='classes', type=int, default=3, help='class number')
    parser.add_argument('--save-path', dest='save_path', type=str, default='./model', help='save model path!')
    parser.add_argument('--pre', dest='pre_training', action='store_true', help='pre_training or not!')
    parser.add_argument('--model-path', dest='model_path', type=str, default='./model/epoch0_acc0.3676_loss1.0442.pt',
                        help='pre_training model path!')
    parser.add_argument('--focal-loss', dest='focal_loss', action='store_true', help='use focal loss!')
    # The default stays the bool True; an explicit value (e.g. "--fe false") is
    # parsed by _str2bool instead of being kept as a (always truthy) string.
    parser.add_argument('--fe', dest='feature_extract', type=_str2bool, default=True,
                        help='Flag for feature extracting (true/false). When False, we finetune the whole model; '
                             'when True, we only update the reshaped layer params!')
    args = parser.parse_args(argv)
    return args


args = parse_args()

# Train and val share one deterministic pipeline (no random augmentation):
# resize the shorter side to --resize, take a --crop-size centre crop, convert
# to a tensor and normalise with the ImageNet mean/std that torchvision's
# pretrained weights expect. Each phase gets its own Compose instance.
transform = {
    phase: transforms.Compose([
        transforms.Resize(args.resize),
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    for phase in ('train', 'val')
}


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of *model* when *feature_extracting* is truthy.

    Frozen parameters have ``requires_grad=False``, so backward() computes no
    gradient for them. When *feature_extracting* is falsy the model is left
    untouched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False


def Net(feature_extract, net_name, num_classes=None):
    """Build an ImageNet-pretrained torchvision network with a new output layer.

    The backbone is created with ``pretrained=True``. When *feature_extract*
    is truthy, all of its parameters are frozen first, so only the freshly
    created head layer(s) remain trainable. The module list is printed and the
    model is moved to the module-level ``device``.

    Args:
        feature_extract: freeze the pretrained parameters when truthy.
        net_name: one of 'mobilenet_v2'/'Mobilenet_v2', 'googlenet'/'Googlenet',
            'inception_v3'/'Inception_v3', 'resnet50'/'Resnet50',
            'densenet121'/'Densnet121'/'Densenet121'.
        num_classes: number of output classes; defaults to ``args.classes``.

    Returns:
        The model on ``device``.

    Raises:
        ValueError: if *net_name* is not one of the names above. (An ``assert``
            was used before, which disappears under ``python -O``.)
    """
    aliases = {
        'Mobilenet_v2': 'mobilenet_v2', 'mobilenet_v2': 'mobilenet_v2',
        'googlenet': 'googlenet', 'Googlenet': 'googlenet',
        'inception_v3': 'inception_v3', 'Inception_v3': 'inception_v3',
        'resnet50': 'resnet50', 'Resnet50': 'resnet50',
        'densenet121': 'densenet121', 'Densnet121': 'densenet121', 'Densenet121': 'densenet121',
    }
    arch = aliases.get(net_name)
    if arch is None:
        raise ValueError('please add yourself net or input right net name! (got {!r})'.format(net_name))
    if num_classes is None:
        num_classes = args.classes

    if arch == 'mobilenet_v2':
        net = torchvision.models.mobilenet_v2(pretrained=True)
        set_parameter_requires_grad(net, feature_extract)
        net.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(net.last_channel, num_classes),
        )
    elif arch == 'googlenet':
        net = torchvision.models.googlenet(pretrained=True)
        set_parameter_requires_grad(net, feature_extract)
        net.fc = nn.Linear(net.fc.in_features, num_classes, bias=True)
    elif arch == 'inception_v3':
        net = torchvision.models.inception_v3(pretrained=True)
        set_parameter_requires_grad(net, feature_extract)
        # Inception-v3 also has an auxiliary classifier; replace both heads.
        net.AuxLogits.fc = nn.Linear(net.AuxLogits.fc.in_features, num_classes)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    elif arch == 'resnet50':
        net = torchvision.models.resnet50(pretrained=True)
        set_parameter_requires_grad(net, feature_extract)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    else:  # 'densenet121'
        net = torchvision.models.densenet121(pretrained=True)
        set_parameter_requires_grad(net, feature_extract)
        net.classifier = nn.Linear(net.classifier.in_features, num_classes)

    print(list(net.children()))
    net = net.to(device)
    return net


def _print_progress(step, steps, step_seconds, loss_value, width=29):
    """Print a progress bar with an ETA extrapolated from the last step's duration."""
    done = int(width * step / steps)
    # int() drops the microseconds, so str() gives a clean H:MM:SS.
    # (Slicing off 7 characters fails when the duration is a whole second.)
    eta = datetime.timedelta(seconds=int((steps - step) * step_seconds))
    print("\r%d/%d [%s>%s] -ETA: %s - loss: %.4f\n" % (
        step, steps, '=' * done, '.' * (width - done), eta, loss_value), end='', flush=True)


def train(net, data_loader, optim, criterion, exp_lr_scheduler, epochs, net_name):
    """Train *net* and keep the weights with the best validation accuracy.

    Each epoch runs a 'train' phase (forward, backward, optimizer step) and
    then a 'val' phase (forward only). Whenever validation accuracy improves,
    ``net.state_dict()`` is saved to
    ``<args.save_path>/epoch{E}_acc{A}_loss{L}.pt``. At the end the best weights
    are loaded back into *net* and saved as ``best_model_acc{A}.pt``. Only
    state_dicts are saved, not whole model objects.

    Args:
        net: model to train, already on the module-level ``device``.
        data_loader: dict with 'train' and 'val' DataLoaders.
        optim: optimizer over the trainable parameters.
        criterion: loss called as ``criterion(outputs, labels)``; must return a scalar.
        exp_lr_scheduler: LR scheduler, stepped once per epoch after the
            training phase.
        epochs: number of epochs.
        net_name: network name; the Inception-v3 aliases enable the auxiliary loss.

    Returns:
        *net*, with the best validation weights loaded.
    """
    os.makedirs(args.save_path, exist_ok=True)  # torch.save does not create directories
    is_inception = net_name in ('inception_v3', 'Inception_v3')
    best_acc = 0.0
    best_model_wts = copy.deepcopy(net.state_dict())

    for epoch in range(epochs):
        print('Epoch{}/{}'.format(epoch, epochs - 1))

        for phase in ('train', 'val'):
            is_train = phase == 'train'
            net.train(is_train)  # train(False) is the same as eval()
            loader = data_loader[phase]
            steps = len(loader)
            running_loss = 0.0
            running_corrects = 0

            for step, (inputs, labels) in enumerate(loader, start=1):
                start_time = time.time()
                inputs = inputs.to(device)
                labels = labels.to(device)
                optim.zero_grad()

                with torch.set_grad_enabled(is_train):
                    if is_inception and is_train:
                        # In training mode Inception-v3 also returns auxiliary
                        # logits; their loss is added with weight 0.4.
                        outputs, aux_outputs = net(inputs)
                        loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)
                    else:
                        outputs = net(inputs)
                        loss = criterion(outputs, labels)
                    _, predict = torch.max(outputs, 1)

                    if is_train:
                        loss.backward()
                        optim.step()

                batch_loss = loss.item()
                running_loss += batch_loss * inputs.size(0)
                running_corrects += torch.sum(predict == labels).item()
                _print_progress(step, steps, time.time() - start_time, batch_loss)

            if is_train:
                # Since PyTorch 1.1 the scheduler must be stepped after
                # optimizer.step(); stepping it first skips the first LR value.
                exp_lr_scheduler.step()

            # Sample count comes from the loader's own dataset rather than a
            # global defined only under __main__; guard against an empty split.
            num_samples = max(len(loader.dataset), 1)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples

            print('{} Loss:{:.4f} acc:{:.4f}'.format(phase, epoch_loss, epoch_acc))
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(net.state_dict(), os.path.join(
                    args.save_path, 'epoch{}_acc{:.4f}_loss{:.4f}.pt'.format(epoch, epoch_acc, epoch_loss)))
                best_model_wts = copy.deepcopy(net.state_dict())

    print('Best val acc', best_acc)
    net.load_state_dict(best_model_wts)
    torch.save(net.state_dict(), os.path.join(args.save_path, 'best_model_acc{:.4f}.pt'.format(best_acc)))
    return net


if __name__ == '__main__':
    phases = ('train', 'val')

    # Load data: <args.data>/train and <args.data>/val, one sub-folder per class.
    imgs_datasets = {x: torchvision.datasets.ImageFolder(os.path.join(args.data, x), transform=transform[x])
                     for x in phases}
    data_loader = {x: torch.utils.data.DataLoader(imgs_datasets[x], batch_size=args.batch_size,
                                                  shuffle=(x == 'train'), num_workers=0)
                   for x in phases}
    # Per-phase sample counts, kept as a module-level name for code that reads it globally.
    data_size = {x: len(imgs_datasets[x]) for x in phases}
    img_class = imgs_datasets['train'].classes  # class names in label-index order

    net = Net(args.feature_extract, args.net_name)

    # Optionally resume from a saved state_dict. map_location lets a checkpoint
    # written on a GPU be loaded on a CPU-only machine.
    if args.pre_training:
        print('load model:{}'.format(args.model_path))
        net.load_state_dict(torch.load(args.model_path, map_location=device))

    # With feature extraction only the still-trainable parameters (the new
    # head) are optimized; otherwise every parameter is.
    if args.feature_extract:
        params_to_update = [p for p in net.parameters() if p.requires_grad]
    else:
        params_to_update = net.parameters()
    optim = torch.optim.Adam(params=params_to_update, lr=0.001)

    # `import FocalLoss` binds the *module*; the loss class inside it is
    # FocalLoss.FocalLoss (calling the module itself raises TypeError).
    criterion = FocalLoss.FocalLoss() if args.focal_loss else nn.CrossEntropyLoss()
    print('criterion: ', criterion)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim, step_size=7, gamma=0.1)
    train(net, data_loader, optim, criterion, exp_lr_scheduler, args.epochs, args.net_name)

Focal loss 代碼(保存為 FocalLoss.py,與訓練腳本放在同一目錄):

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Focal loss for classification logits, in softmax or binary-sigmoid form.

    For a sample with target class t and predicted probability p_t (clamped to
    ``[1e-8, 1 - 1e-8]``) the loss is ``-w[t] * (1 - p_t) ** gamma * log(p_t)``.
    With ``gamma=0`` and unit weights it reduces to cross-entropy.

    Args:
        gamma: focusing exponent applied to ``(1 - p_t)``.
        size_average: if True return the mean over samples (a scalar);
            otherwise return the per-sample losses with shape ``(N, 1)``.
    """

    def __init__(self, gamma=2, size_average=True):
        super().__init__()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, logit, target, class_weight=None, type='softmax'):  # `type` kept for API compatibility
        """Compute the focal loss.

        Args:
            logit: raw scores. For ``type='softmax'`` shape ``(N, C)``; for
                ``type='sigmoid'`` one score per sample (flattened to N values),
                interpreted as the logit of class 1.
            target: integer class indices (flattened to N values); moved to
                the device of *logit*.
            class_weight: optional per-class weights (length C, or 2 for
                sigmoid), list or tensor; defaults to all ones.
            type: 'softmax' or 'sigmoid'.

        Returns:
            Scalar mean loss, or per-sample losses ``(N, 1)`` when
            ``size_average`` is False.

        Raises:
            ValueError: if *type* is neither 'softmax' nor 'sigmoid', or a
                softmax *logit* is not 2-D.
        """
        device = logit.device
        target = target.view(-1, 1).long().to(device)

        if type == 'sigmoid':
            num_classes = 2
            p1 = torch.sigmoid(logit).view(-1, 1)
            prob = torch.cat((1 - p1, p1), 1)
        elif type == 'softmax':
            _, num_classes = logit.size()
            prob = F.softmax(logit, 1)
        else:
            raise ValueError("type must be 'softmax' or 'sigmoid', got {!r}".format(type))

        if class_weight is None:
            class_weight = [1] * num_classes

        # Build the one-hot mask and weights on the logits' device: the old
        # hard-coded .cpu() tensors made the loss crash for CUDA inputs.
        select = F.one_hot(target.view(-1), num_classes).to(prob.dtype)
        weight = torch.as_tensor(class_weight, dtype=prob.dtype, device=device).view(-1, 1)
        weight = torch.gather(weight, 0, target)  # per-sample weight w[t_i], shape (N, 1)

        prob = (prob * select).sum(1).view(-1, 1)  # p_t for each sample
        prob = torch.clamp(prob, 1e-8, 1 - 1e-8)   # keep log() finite
        batch_loss = -weight * torch.pow(1 - prob, self.gamma) * prob.log()

        return batch_loss.mean() if self.size_average else batch_loss

3. 模型測試

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from collections import Counter

# Evaluate on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


def Load_Data(img_size, imgs_path, resize=None, batch_size=8):
    """Build the ImageFolder test set and a non-shuffled DataLoader over it.

    Each image is resized so that its shorter side equals *resize* (or
    *img_size* when *resize* is None). It is then centre-cropped to
    ``img_size x img_size``, converted to a tensor and normalised with the
    ImageNet mean/std. The centre crop gives every sample the same shape,
    which the DataLoader needs in order to stack images of different aspect
    ratios into one batch. To reproduce the training preprocessing, pass the
    training ``--resize`` as *resize* and ``--crop-size`` as *img_size*.

    Args:
        img_size: side length of the square network input.
        imgs_path: root folder with one sub-folder per class.
        resize: optional shorter-side size applied before cropping.
        batch_size: DataLoader batch size.

    Returns:
        ``(testset, testloader)``.
    """
    transform = transforms.Compose([
        transforms.Resize(img_size if resize is None else resize),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    testset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)
    return testset, testloader


def Load_Model(model_path, net=None):
    """Load a checkpoint for evaluation and return a model in eval mode on ``device``.

    Two checkpoint formats are handled:

    * a whole pickled model (``torch.save(model, path)``), which is returned as is;
    * a state_dict, which is what the training script writes
      (``torch.save(net.state_dict(), ...)``). This needs *net*, an untrained
      instance of the same architecture, e.g.::

          net = torchvision.models.mobilenet_v2(pretrained=False)
          net.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(net.last_channel, 3))

    The model is switched to eval mode so that dropout and batch-norm behave
    as at inference time.

    Raises:
        TypeError: if the file holds a state_dict and *net* is None.
    """
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    # On recent PyTorch releases torch.load defaults to weights_only=True, which
    # rejects whole pickled models; for a trusted full-model file pass
    # weights_only=False there (TODO confirm against the installed version).
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, dict):
        if net is None:
            raise TypeError('{} contains a state_dict, not a model: pass the network '
                            'architecture as `net`'.format(model_path))
        net.load_state_dict(checkpoint)
        model = net
    else:
        model = checkpoint
    model = model.to(device)
    model.eval()
    return model


def test(data, model):
    """Print the per-class accuracy of *model* on a test set.

    Args:
        data: ``(testset, testloader)`` as returned by Load_Data; *testset*
            must expose ``targets`` and ``classes`` (ImageFolder does).
        model: classifier returning one row of class scores per input.

    Returns:
        dict mapping class name to accuracy (None for a class with no samples).
    """
    testset, testloader = data
    classes_list = testset.classes
    label_counts = Counter(testset.targets)  # computed once, not per class
    dict_right = {i: 0 for i in range(len(classes_list))}

    model.eval()  # dropout / batch-norm in inference mode
    with torch.no_grad():
        for inputs, labels in testloader:
            # .to() returns a new tensor; the result must be kept or the
            # batch stays on the CPU while the model is on the GPU.
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            print('predicted:', predicted, 'labels:', labels)
            for label, pred in zip(labels.tolist(), predicted.tolist()):
                if pred == label:
                    dict_right[label] += 1

    accuracies = {}
    for i, name in enumerate(classes_list):
        count = label_counts[i]
        if count == 0:
            # Avoid ZeroDivisionError for a class folder without images.
            accuracies[name] = None
            print('Class {} Acc: n/a (no samples)'.format(name))
            continue
        acc = dict_right[i] / count
        accuracies[name] = acc
        print('Class {} Acc: {:.4f}'.format(name, acc))
    return accuracies


if __name__ == '__main__':
    # Evaluation settings: checkpoint file, test-set root folder, network input size.
    model_path = r'./model/epoch0_acc0.9044_loss0.2180.pt'
    imgs_path = './data/class/val'
    img_size = 224

    model = Load_Model(model_path)
    test(Load_Data(img_size=img_size, imgs_path=imgs_path), model)

4. 單張圖片預測

import os
import torch
from torchvision import transforms
from PIL import Image
import time

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
img_size = 224      # square network input; should equal the training --crop-size
resize_size = 300   # shorter-side resize before cropping; should equal the training --resize

# NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model = torch.load('./model/epoch0_acc0.9044_loss0.2180.pt', map_location=device)
if isinstance(model, dict):
    # The training script saves net.state_dict(): weights only, no architecture.
    raise TypeError('checkpoint contains a state_dict, not a model: build the network '
                    'first and call load_state_dict() on it')
model = model.to(device)
model.eval()  # dropout / batch-norm in inference mode

# Same preprocessing as training: resize, centre crop, tensor, ImageNet normalisation.
data_transforms = transforms.Compose([
    transforms.Resize(resize_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def predict_image(img_path):
    """Classify one image file and print the elapsed time and predicted class index.

    Uses the module-level ``model``, ``data_transforms`` and ``device``.

    Returns:
        The predicted class index as an int.
    """
    start_time = time.time()
    # The context manager closes the file. convert('RGB') makes grayscale,
    # palette or RGBA images 3-channel, matching the 3-value Normalize.
    with Image.open(img_path) as img:
        batch = data_transforms(img.convert('RGB')).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(batch)
    _, predict = torch.max(output, 1)
    predicted_class = predict.item()  # works for CPU and CUDA tensors alike
    end_time = time.time()
    print('use_time', end_time - start_time)
    print('predicted classes: ', predicted_class)
    return predicted_class


# Interactive loop: classify each entered path until the user types 'q'.
while True:
    img_path = input('please input img_path:')
    if img_path == 'q':
        break
    if os.path.exists(img_path):
        predict_image(img_path)
    else:
        print("The path error, Try again!")

參考文章:https://blog.csdn.net/weixin_40123108/article/details/85714030

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值