(二)代码架构拆分

总述

下面首先会给出一个toy model以及其基本的训练流程,然后再对其进行拓展加入三个常用模块,分别是(1)管理模型超参数的argparse(2)save&load模块(3)监视训练过程的SummaryWriter

一.code

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np

from torch.utils.tensorboard import SummaryWriter

import argparse
import os
import shutil

def prepare_folders(args):
    """Create the checkpoint directory tree ahead of time.

    ``torch.save`` does not create missing directories, so the model root
    and the per-run subdirectory must exist before the first checkpoint
    write.

    Args:
        args: parsed arguments; must provide ``root_model`` and
            ``store_name`` attributes.
    """
    folders_util = [args.root_model,
                    os.path.join(args.root_model, args.store_name)]
    for folder in folders_util:
        if not os.path.exists(folder):
            print('creating folder ' + folder)
            # makedirs (not mkdir) so a missing parent of root_model
            # does not raise FileNotFoundError; exist_ok guards against
            # a concurrent creation race.
            os.makedirs(folder, exist_ok=True)

class AverageMeter(object):
    """Accumulates a stream of values, exposing the most recent value
    (``val``) and the running mean (``avg``)."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Fold in ``val``, observed over ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**vars(self))


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: ``(batch, num_classes)`` score tensor.
        target: ``(batch,)`` ground-truth class indices.
        topk: tuple of k values to evaluate.

    Returns:
        List of 1-element tensors, one percentage per requested k.
    """
    with torch.no_grad():
        k_max = max(topk)
        n_samples = target.size(0)

        # Indices of the k_max highest-scoring classes, transposed to
        # shape (k_max, batch) so row i holds every sample's i-th guess.
        _, top_idx = output.topk(k_max, 1, True, True)
        top_idx = top_idx.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        # A sample counts for k if the label appears in its first k guesses.
        return [hits[:k].reshape(-1).float().sum(0, keepdim=True)
                * (100.0 / n_samples)
                for k in topk]

def save_checkpoint(args, state, is_best):
    """Persist ``state`` under the run folder; mirror it to a ``best``
    file whenever ``is_best`` is set.

    The run folder ``<root_model>/<store_name>`` must already exist
    (see ``prepare_folders``).
    """
    ckpt_path = f'{args.root_model}/{args.store_name}/ckpt.pth.tar'
    torch.save(state, ckpt_path)
    if not is_best:
        return
    # ckpt.pth.tar -> ckpt.best.pth.tar alongside the rolling checkpoint.
    best_path = ckpt_path.replace('pth.tar', 'best.pth.tar')
    shutil.copyfile(ckpt_path, best_path)
    
##########################################################################################
#以上是utils
##########################################################################################
class LeNet5(nn.Module):
    """LeNet-5-style CNN for 3x32x32 inputs with 10 output classes."""

    def __init__(self):
        super(LeNet5, self).__init__()
        # Two conv stages; each is followed by 2x2 max-pooling in forward().
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # After the second pool the feature map is 16 channels of 5x5.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # conv -> relu -> pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten to (batch, 400) for the classifier head.
        x = x.view(-1, 16 * 5 * 5)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        # Raw logits; CrossEntropyLoss applies softmax internally.
        return self.fc3(x)


def train_loop(dataloader, model, loss_fn, optimizer, device, tf_writer, epoch):
    """Run one training epoch; log epoch-average loss and top-1 accuracy
    to TensorBoard under the given ``epoch`` step.

    Also prints the mean loss of every 100-batch window as a console
    progress log.
    """
    top1 = AverageMeter('Acc@1', ':6.2f')
    losses = AverageMeter('Loss', ':.4e')
    running_loss = 0.0

    for batch_idx, (inputs, labels) in enumerate(dataloader, 0):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass and metric bookkeeping.
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        losses.update(loss.item(), inputs.size(0))
        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        top1.update(acc1[0], inputs.size(0))

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Console log: mean loss over each 100-batch window.
        running_loss += loss.item()
        if (batch_idx + 1) % 100 == 0:
            print('[Batch%4d] loss: %.3f' %
                  (batch_idx + 1, running_loss / 100))
            running_loss = 0.0

    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)

def test_loop(dataloader, model, loss_fn, device, tf_writer, epoch):
    """Evaluate the model over the full loader.

    Prints the overall accuracy, logs epoch-average loss and top-1
    accuracy to TensorBoard, and returns the top-1 average (used by the
    caller for early stopping / checkpointing).
    """
    # NOTE(review): model.eval() is not called here; harmless for this
    # LeNet (no dropout/batchnorm) but worth confirming for other models.
    flag = 'val'
    correct, total = 0, 0
    top1 = AverageMeter('Acc@1', ':6.2f')
    losses = AverageMeter('Loss', ':.4e')

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            # Plain argmax counters for the console summary line.
            predicted = torch.max(outputs.data, 1)[1]
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Meter bookkeeping, mirroring train_loop.
            loss = loss_fn(outputs, labels)
            losses.update(loss.item(), images.size(0))
            acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
            top1.update(acc1[0], images.size(0))

    print('Accuracy of the network on the 10000 test images: %d %%' % (
            100 * correct / total))

    tf_writer.add_scalar('loss/test_' + flag, losses.avg, epoch)
    tf_writer.add_scalar('acc/test_' + flag + '_top1', top1.avg, epoch)

    return top1.avg

def parse_args():
    """Define and parse the command-line hyperparameters.

    Returns:
        argparse.Namespace with training and early-stopping settings.
    """
    parser = argparse.ArgumentParser(description='toy model' )

    # Training-related parameters.
    parser.add_argument('--dataset', default='cifar10', help='dataset setting')
    parser.add_argument('--loss_type', default="CE", type=str, help='loss type')
    parser.add_argument('--root_model', type=str, default='/kaggle/working/checkpoint')
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=50)

    # Early-stopping parameters: stop after `patience` epochs without an
    # accuracy improvement of at least `delta`.
    parser.add_argument('--delta', type=float, default=0.)
    parser.add_argument('--patience', type=int, default=5)

    return parser.parse_args()


if __name__ == '__main__':
    ## Initialization
    # Parse hyperparameters and derive a unique run name for checkpoints.
    args = parse_args()
    args.store_name = '_'.join([args.dataset, args.loss_type, str(args.epochs)])
    prepare_folders(args)

    # Device selection.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    # TensorBoard writer for training-process visualization.
    writer = SummaryWriter("logs")

    ## Dataset preparation
    # Preprocessing: to tensor, then normalize each channel to [-1, 1].
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Training split.
    trainset = datasets.CIFAR10(root='./cifar10_data', train=True,
                                download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=128,
                             shuffle=True, num_workers=2)

    # Test split.
    testset = datasets.CIFAR10(root='./cifar10_data', train=False,
                               download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=128,
                            shuffle=False, num_workers=2)

    ## Training state
    # Early-stopping bookkeeping.
    counter = 0
    # Fixed: was initialized as a list ([]) but later assigned True and
    # truth-tested; a boolean is the intended type.
    early_stop = False
    best_acc1 = 0

    # Model, loss, and optimizer instances.
    model = LeNet5().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for t in range(args.epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_loop(trainloader, model, loss_fn, optimizer, device=device, tf_writer=writer, epoch=t)
        acc1 = test_loop(testloader, model, loss_fn, device=device, tf_writer=writer, epoch=t)

        # Early stopping: best_acc1 is updated before the comparison, so
        # acc1 <= best_acc1 always holds below and the counter grows on
        # every epoch that fails to improve by at least args.delta.
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # Checkpoint on improvement; otherwise count toward early stop.
        if acc1 < best_acc1 + args.delta:  # accuracy did not improve enough
            counter += 1
            print(f'EarlyStopping counter: {counter} out of {args.patience}')
            if counter >= args.patience:  # patience exhausted -> stop
                early_stop = True
        else:
            save_checkpoint(args, {
                'epoch': t + 1,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
            counter = 0

        if early_stop:
            print("Early stopping")
            # End training early.
            break

    writer.close()
    print("Done!")

二.argparse

如果我们对参数随便使用、随便命名的话,那么初始化的变量会弥漫在代码之中,非常不便于管理,所以argparse这个库就提供了一套统一管理超参数的方法。

1.基本框架

import argparse
parser = argparse.ArgumentParser(description='toy model' )
...
parser.add_argument('--str',default=,type=,help=)
...
args = parser.parse_args()

第二句中的description就随便给个字符串名字就行
第三句中str替换为超参数名称,default是超参数的默认取值,type是取值类型,help是命令行会提示你这个变量应该怎么去给值(.add_argument()的常用参数就是这些,更多的可以查看专门的文档或者通过代码实验)

2.命令行执行程序

在所有代码之前先加上一句

%%writefile  toy_model.py

这句话的意思就是把这个cell的代码,保存为toy_model.py文件,文件名可以进行替换的
运行之后得到如下结果:
在这里插入图片描述
得到了py文件后,再去cell中用命令行的语言运行即可

!python /kaggle/working/toy_model.py --epochs 1

在这里插入图片描述

3.要点说明

(1)如果是用命令行代码运行程序的话,需要保证parser.parse_args()没有输入的参数,这样就能在命令行自己设置一些参数;如果是希望全部用default值,并且希望直接在cell中运行,需要加入args=[]的输入,如:parser.parse_args(args=[]),不加就会如下报错。
在这里插入图片描述

(2)toy model中是用一个函数封装的该部分,这样会使得代码更加简洁,逻辑性更强

三.save&load

四.SummaryWriter

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值