行人重识别(20)——代码实践之模型训练(train_model(局部对齐模型).py)

!转载请注明原文地址!——东方旅行者

更多行人重识别文章移步我的专栏:行人重识别专栏

train_model(局部对齐模型).py

一、train_model(局部对齐模型).py作用

本文件是行人重识别系统的核心文件,用于局部对齐模型训练。该文件由train_model.py(度量学习)文件进行部分修改形成。

二、train_model(局部对齐模型).py与train_model(度量学习).py的不同

需要修改的地方有:

  1. 在引用中加入对“全局特征与局部特征难样本挖掘三元组损失”与“局部对齐最小距离算法”的引入(第14,16行)
  2. 在指定度量损失函数时,使用引入的带难样本挖掘的全局与局部三元组损失(第94行)
  3. 计算度量损失返回一个全局度量损失,一个局部度量损失,两者都使用了难样本挖掘。并且在计算总损失时也是分类损失、全局度量损失与局部度量损失加和。(第145,146行)

三、代码

import os,sys,time,datetime
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from IPython import embed

import transform as T
from model.ReIDNet import ReIDNet
from dataset_manager import Market1501
from dataset_loader import ImageDataset
from LG_TriHard_Loss import AlignedTripletLoss#!!与度量学习不同!!引入全局特征与局部特征难样本挖掘三元组损失
from sampler import RandomIdentitySampler
from local_distance import batch_local_dist#!!与度量学习不同!!引入局部对齐最小距离算法

"""
本文件是行人重识别系统的核心文件,用于局部对齐模型训练。
"""
#设定输入参数
width=64                    #图片宽度
height=128                 #图片高度
train_batch_size=32  #训练批量
test_batch_size=32  #测试批量
train_lr=0.01                #学习率
start_epoch=0           #开始训练的批次
end_epoch=1                 #结束训练的批次
dy_step_size=800      #动态学习率变化步长
dy_step_gamma=0.9  #动态学习率变化倍数
evaluate=False           #是否测试
margin=0.3                 #TripletHard Loss计算的margin参数
num_instances=4        #每个ID图片数,一定要能被batch_size整除
metric_only=False      #是否只用TriHardLoss
max_acc=-1#最大准确率
best_model_path='./model/param/aligned_trihard_net_params_best.pth'#最优模型保存地址
final_model_path='./model/param/aligned_trihard_net_params_final.pth'#最终模型保存地址

def main():
    #数据集加载
    dataset=Market1501()
    
    #训练数据处理器
    transform_train=T.Compose(
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值