Person_reID_baseline_pytorch 源码解析之 train.py

脚本 train.py 是用来训练模型的脚本,训练模型首先需要载入数据集,然后开始训练过程,训练完成后可以根据训练结果绘制 loss 曲线图,并保存训练好的模型参数。本文将按照训练模型的流程,分别解析对应步骤的代码。

1. 载入数据集

通过执行数据处理脚本 prepare.py ,我们已经将数据集组织成了 datasets.ImageFolder 可以直接使用的数据集结构。要想将数据集载入模型还需要将数据集张量化并生成数据集迭代器。

1.1 数据集张量化

使用 datasets.ImageFolder 可以将图片格式的数据集变为 pytorch 支持的张量 tensor ,如果对 transform 参数进行设置,则会对数据集的图片进行数据增强等变换。

调用 datasets.ImageFolder 后生成了 pytorch 支持的数据集 image_datasets[‘train’] 和 image_datasets[‘val’] 。

image_datasets = {}
image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                          data_transforms['train'])
image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                          data_transforms['val'])

可以通过 pytorch 的 transforms 库引入 transform,针对训练集和测试集进行不同的 transform 变化

from torchvision import datasets, transforms
transform_train_list = [
        #transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
        transforms.Resize((h, w), interpolation=3),
        transforms.Pad(10),
        transforms.RandomCrop((h, w)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]

transform_val_list = [
        transforms.Resize(size=(h, w),interpolation=3), #Image.BICUBIC
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]
        
data_transforms = {
    'train': transforms.Compose(transform_train_list),
    'val': transforms.Compose(transform_val_list),
}

1.2 数据集迭代器

训练模型时,一般不会一次性把所有数据都加载到模型中。通常采用 mini_batch 的方法,按照 batchsize 的大小将一个 batch 的数据载入到模型中。pytorch 框架支持用 torch.utils.data.DataLoader 作为 dataloader 载入数据。

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=True, num_workers=0, pin_memory=True) # 8 workers may work faster
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

将 image_datasets[‘train’] 和 image_datasets[‘val’] 输入 torch.utils.data.DataLoader 后,获得了两个迭代器 dataloaders[‘train’] and dataloaders[‘val’] 。

下面来介绍一下 torch.utils.data.DataLoader 的主要参数

class torch.utils.data.DataLoader(dataset, 
								batch_size=1, 
								shuffle=False, 
								sampler=None, 
								num_workers=0, 
								collate_fn=<function default_collate>, 
								pin_memory=False, 
								drop_last=False)

torch.utils.data.DataLoader 将返回一个数据迭代器。

参数说明:

  • dataset (Dataset) – 加载数据的数据集
  • batch_size (int) – 每个batch加载多少个样本(默认: 1)
  • shuffle (bool) – 设置为True时会在每个epoch重新打乱数据(默认: False)
  • sampler (Sampler) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数
  • num_workers (int) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
  • drop_last (bool, optional) – 如果数据集大小不能被 batch size 整除,则设置为 True 后可删除最后一个不完整的batch。如果设为 False 并且数据集的大小不能被 batch size 整除,则最后一个batch将更小。(默认: False)

2. 开始训练

在函数 train_model 中,实现了模型训练过程。网络模型一般会迭代多轮以达到一个很好的训练效果,通常通过循环执行一段训练代码来实现迭代训练。

2.1 训练代码

下面对主要的训练代码进行解析:

			# Iterate over data.
            for data in dataloaders[phase]:
                # 载入一个 batch 的输入
                # 数据迭代器返回一个 batch 的图像及其标签
                inputs, labels = data
                now_batch_size,c,h,w = inputs.shape
                if now_batch_size<opt.batchsize: # skip the last batch
                    continue
                # print(inputs.shape)
                # 变量化输入
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)
				# 开始训练
                # 将梯度参数置零
                optimizer.zero_grad()
				
				# 前向传播,计算损失
                #-------- forward --------
                outputs = model(inputs)
                # preds 是 softmax 概率最大的类别的索引, 即模型预测的类别
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)
				
				# 只在 train 模式下执行,反向传播,梯度下降优化, 
                #-------- backward + optimize -------- 
                # only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

训练过程中,还可以使用 warm_up 等学习率策略。

2.2 模型加载

模型训练过程中,还会涉及到模型加载。在训练模式下,模型的网络参数会发生改变;而在验证模式下,一般不进行梯度下降反向传播等操作,我们希望网络参数保持不变。此时会考虑使用 model.load_state_dict 加载最佳模型参数进行验证。

注意
model.load_state_dict 是深拷贝,可以保证加载的是最佳模型参数
model.state_dict 是浅拷贝,保存的是最后一轮训练的模型参数

另外使用预训练迁移模型的部分层参数时,记得令 strict=False,即
model.load_state_dict(state_dict, strict=False)。strict 默认为 True,表示严格按照名称加载参数,如果出现未定义的名称,就会报错。如果将 strict=False,则会忽略未定义的名称,不会报错。

            # deep copy the model
            if phase == 'val':
                last_model_wts = model.state_dict()
                if epoch%10 == 9:
                    save_network(model, epoch)
                draw_curve(epoch)
            if phase == 'train':
               scheduler.step()
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    #print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(last_model_wts)
    save_network(model, 'last')

3. 结果保存

训练过程中,一般会保存训练好的模型参数,方便下次训练时加载模型。为了监控训练过程,一般还会绘制 loss 曲线。

3.1 模型保存

baseline 通过 torch.save 实现模型参数的保存,具体代码如下:

# Save model
#---------------------------
def save_network(network, epoch_label):
    save_filename = 'net_%s.pth'% epoch_label
    # save_path = os.path.join('./model',name,save_filename)
    save_path = os.path.join('model', name, save_filename)
    torch.save(network.cpu().state_dict(), save_path)
    if torch.cuda.is_available():
        network.cuda(gpu_ids[0])

pytorch 一般使用如下代码实现模型的保存和加载

# save
torch.save(model.state_dict(), PATH)

# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

3.2 loss 曲线绘制

使用 pyplot 库可以实现绘图,loss 曲线绘制代码如下:

# Draw Curve
#---------------------------
import matplotlib.pyplot as plt
x_epoch = []
fig = plt.figure()
ax0 = fig.add_subplot(121, title="loss")
ax1 = fig.add_subplot(122, title="top1err")
def draw_curve(current_epoch):
    x_epoch.append(current_epoch)
    ax0.plot(x_epoch, y_loss['train'], 'bo-', label='train')
    ax0.plot(x_epoch, y_loss['val'], 'ro-', label='val')
    ax1.plot(x_epoch, y_err['train'], 'bo-', label='train')
    ax1.plot(x_epoch, y_err['val'], 'ro-', label='val')
    if current_epoch == 0:
        ax0.legend()
        ax1.legend()
    # fig.savefig( os.path.join('./model',name,'train.jpg'))
    fig.savefig(os.path.join('model', name, 'train.jpg'))

参考文献

  1. 从零开始行人重识别
  2. Person_reID_baseline_pytorch
  3. torch.max()使用讲解
  4. 源码详解Pytorch的state_dict和load_state_dict
  5. Pytorch踩坑记:赋值、浅拷贝、深拷贝三者的区别以及model.state_dict()和model.load_state_dict()的坑点
  6. torch.load_state_dict()函数的用法总结
  • 4
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值