深度学习系列笔记——伍(基于pytorch1.5的猫狗大战训练)

先贴出完整代码，后续会找时间对代码逐段解释，并给出相应的测试代码。

这次用的是resnet50在ImageNet1000上的预训练模型

# coding=utf-8
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset
from torchvision import transforms, datasets
from matplotlib import pyplot as plt
from tqdm import tqdm


def plot_train_history(epochs_history, train_acc_history, val_acc_history, train_loss_history, val_loss_history):
    """Plot accuracy and loss curves for a finished training run.

    Reusable for other training loops: pass the per-epoch x values plus the
    four metric histories; one figure with two side-by-side subplots is shown.
    """
    plt.figure(figsize=(10, 10))

    # Left panel: accuracy curves.
    acc_axes = plt.subplot(1, 2, 1)
    acc_axes.plot(epochs_history, train_acc_history, label='Training Accuracy')
    acc_axes.plot(epochs_history, val_acc_history, label='Validation Accuracy')
    acc_axes.legend(loc='lower right')
    acc_axes.set_title('Training and Validation Accuracy')

    # Right panel: loss curves.
    loss_axes = plt.subplot(1, 2, 2)
    loss_axes.plot(epochs_history, train_loss_history, label='Training Loss')
    loss_axes.plot(epochs_history, val_loss_history, label='Validation Loss')
    loss_axes.legend(loc='upper right')
    loss_axes.set_title('Training and Validation Loss')

    plt.show()


def train():
    """Fine-tune an ImageNet-pretrained ResNet-50 on the cat-vs-dog dataset.

    Loads images from catVSdog/data/{train,validation}, freezes every layer
    except the new 2-class fully-connected head, trains for a fixed number of
    epochs, checkpoints the best model (by validation accuracy) to disk, and
    finally plots the accuracy/loss curves.

    Side effects: reads the dataset from disk, may load/save 'resnet50.pt',
    prints progress, and opens a matplotlib window.
    """
    epochs = 10        # number of training epochs
    batch_size = 64    # mini-batch size
    num_workers = 0    # DataLoader worker processes
    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')
    PATH = 'resnet50.pt'  # checkpoint path, also used for resuming

    # Resize, center-crop to 224x224, and normalize with ImageNet statistics
    # (the values the pretrained backbone was trained with).
    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # BUG FIX: the original train root started with '\c' (a literal backslash,
    # not a valid escape) and disagreed with the validation root; build both
    # paths portably with os.path.join.
    train_dataset = datasets.ImageFolder(root=os.path.join('catVSdog', 'data', 'train'),
                                         transform=data_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)

    test_dataset = datasets.ImageFolder(root=os.path.join('catVSdog', 'data', 'validation'),
                                        transform=data_transform)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers)

    resnet50 = torchvision.models.resnet50(pretrained=True)
    resnet50.fc = nn.Linear(2048, 2)  # replace the 1000-class head with a 2-class one

    # Freeze the first 9 child modules (everything up to avgpool), leaving
    # only the new fc head trainable.
    for child_index, child in enumerate(resnet50.children()):
        if child_index < 9:
            for param in child.parameters():
                param.requires_grad = False

    if os.path.exists(PATH):
        print("Previous model existed, loading model...")
        resnet50 = torch.load(PATH)

    if use_gpu:
        print('gpu is available')
    else:
        print('gpu is unavailable')
    resnet50 = resnet50.to(device)

    print(resnet50)
    train_loss_history = []
    train_acc_history = []
    val_loss_history = []
    val_acc_history = []
    x = np.arange(1, epochs + 1)

    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    # BUG FIX: pass only trainable parameters to the optimizer, and use a
    # sane fine-tuning learning rate — Adam with lr=0.05 diverges almost
    # immediately on a pretrained network.
    trainable_params = (p for p in resnet50.parameters() if p.requires_grad)
    optimizer = optim.Adam(trainable_params, lr=1e-3)

    # BUG FIX: accuracies below are percentages (0-100); the original
    # threshold of 0.7 made the "better model" check always pass.
    max_accuracy = 70.0
    for epoch in tqdm(range(epochs)):
        # BUG FIX: re-enable train mode every epoch; the original left the
        # model in eval() after the first validation pass, so BatchNorm
        # statistics were frozen for all subsequent epochs.
        resnet50.train()
        running_loss = 0.0
        train_correct = 0
        train_total = 0
        current_train_total = 0
        for step, data in enumerate(tqdm(train_loader), 0):  # indices start at 0
            inputs, train_labels = data
            # BUG FIX: count actual batch size — the last batch may be smaller
            # than batch_size, which skewed the printed running accuracy.
            current_train_total += train_labels.size(0)
            inputs = inputs.to(device)
            labels = train_labels.to(device)
            optimizer.zero_grad()
            outputs = resnet50(inputs)
            _, train_predicted = torch.max(outputs.data, 1)  # argmax over logits = predicted class
            train_correct += (train_predicted == labels.data).sum().item()

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_total += train_labels.size(0)
            print("current train accuracy: %.3f" % (train_correct / current_train_total),
                  "running loss:", loss.item(), end="     ")

        print("-" * 10, "train result", "-" * 10)
        print('train %d epoch loss: %.3f  acc: %.3f ' % (
            epoch + 1, running_loss / train_total, 100 * train_correct / train_total))

        # ---- validation pass ----
        val_correct = 0
        val_loss = 0.0
        val_total = 0
        current_val_total = 0
        resnet50.eval()  # inference mode: BatchNorm uses running statistics
        # Gradients must also be disabled, otherwise activations accumulate
        # and the run goes OOM.
        with torch.no_grad():
            for data in tqdm(test_loader):
                images, labels = data
                current_val_total += labels.size(0)  # BUG FIX: actual batch size
                images = images.to(device)
                labels = labels.to(device)
                outputs = resnet50(images)
                _, predicted = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_total += labels.size(0)
                val_correct += (predicted == labels.data).sum().item()
                print("current validation accuracy: %.3f" % (val_correct / current_val_total),
                      "running loss:", loss.item(), end="     ")
        print("-" * 10, "validation result", "-" * 10)
        print('test  %d epoch loss: %.3f  acc: %.3f ' % (
            epoch + 1, val_loss / val_total, 100 * val_correct / val_total))

        # Record plain Python numbers (not tensors) for plotting.
        current_val_acc = 100 * val_correct / val_total
        train_loss_history.append(running_loss / train_total)
        train_acc_history.append(100 * train_correct / train_total)
        val_loss_history.append(val_loss / val_total)
        val_acc_history.append(current_val_acc)

        if current_val_acc > max_accuracy:
            max_accuracy = current_val_acc
            # BUG FIX: save to PATH so the checkpoint is actually found and
            # resumed on the next run (the original saved to a different path
            # than it loaded from).
            torch.save(resnet50, PATH)
        else:
            print("This epoch does not find a better model than last one.")

    plot_train_history(epochs_history=x, train_acc_history=train_acc_history, val_acc_history=val_acc_history,
                       train_loss_history=train_loss_history, val_loss_history=val_loss_history)

if __name__ == "__main__":
    # Script entry point: run the full fine-tuning workflow.
    train()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值