pytorch----retinaface(训练)

pytorch----retinaface(训练)

train.py

from __future__ import print_function
import os
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
import argparse
import torch.utils.data as data
from data import WiderFaceDetection, detection_collate, preproc, cfg_mnet, cfg_re50
from layers.modules import MultiBoxLoss
from layers.functions.prior_box import PriorBox
import time
import datetime
import math
from models.retinaface import RetinaFace


#argparse是一个Python模块:命令行选项、参数和子命令解析器。
#argparse 模块可以让人轻松编写用户友好的命令行接口。程序定义它需要的参数,然后 argparse 将弄清如何从 sys.argv 解析出那些参数。 argparse 模块还会自动生成帮助和使用手册,并在用户给程序传入无效参数时报出错误信息。

#创建解析器
parser = argparse.ArgumentParser(description='Retinaface Training')
#添加参数:
#1:训练集label
parser.add_argument('--training_dataset', default='./data/widerface/train/label.txt', help='Training dataset directory')
#2:network
parser.add_argument('--network', default='mobile0.25', help='Backbone network mobile0.25 or resnet50')
#3:数据加载中使用的工作线程数
parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading')
#4:初始学习率
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
#5:动量
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
#6:恢复网络
parser.add_argument('--resume_net', default=None, help='resume net for retraining')
#7:恢复epoch
parser.add_argument('--resume_epoch', default=0, type=int, help='resume iter for retraining')
# SGD 一次只进行一次更新,就没有冗余,而且比较快,并且可以新增样本。
#8:SGD的重量衰减
parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
#9:SGD的Gamma更新
parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD')
#10:模型保存路径
parser.add_argument('--save_folder', default='./weights/', help='Location to save checkpoint models')
#parse_args()是将之前add_argument()定义的参数进行赋值,并返回相关的namespace。
args = parser.parse_args()
#检测有无./weights/文件夹如果没有就创建
if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)
#检测使用那种网络模型
cfg = None
if args.network == "mobile0.25":
    cfg = cfg_mnet
elif args.network == "resnet50":
    cfg = cfg_re50
#超参数
#----------------------
rgb_mean = (104, 117, 123) # bgr order分离通道
num_classes = 2#num_classes为标签类别总数
#----------------------
img_dim = cfg['image_size']
num_gpu = cfg['ngpu']
batch_size = cfg['batch_size']
max_epoch = cfg['epoch']
gpu_train = cfg['gpu_train']

num_workers = args.num_workers
momentum = args.momentum
weight_decay = args.weight_decay
initial_lr = args.lr
gamma = args.gamma
training_dataset = args.training_dataset
save_folder = args.save_folder
#生成网络
net = RetinaFace(cfg=cfg)
print("Printing net...")
print(net)

if args.resume_net is not None:
    print('Loading resume network...')
    #加载恢复网络
    state_dict = torch.load(args.resume_net)
    # create new OrderedDict that does not contain `module.`
    #创建不包含“module”的新OrderedDict`
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    #灌参数
    for k, v in state_dict.items():
        head = k[:7]
        if head == 'module.':
            name = k[7:] # remove `module.`
        else:
            name = k
        new_state_dict[name] = v
    #提取神经网络
    net.load_state_dict(new_state_dict)
#多gpu运行
if num_gpu > 1 and gpu_train:
    net = torch.nn.DataParallel(net).cuda()
else:
    net = net.cuda()
#增加运行效率
cudnn.benchmark = True

#优化器
optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=momentum, weight_decay=weight_decay)
criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)

priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))#生成先验框
with torch.no_grad():#主要是用于停止autograd模块的工作,以起到加速和节省显存的作用
    priors = priorbox.forward()#
    priors = priors.cuda()

def train():
    net.train()#用于训练
    epoch = 0 + args.resume_epoch
    print('Loading Dataset...')

    dataset = WiderFaceDetection( training_dataset,preproc(img_dim, rgb_mean))

    epoch_size = math.ceil(len(dataset) / batch_size)
    max_iter = max_epoch * epoch_size#最大通道

    stepvalues = (cfg['decay1'] * epoch_size, cfg['decay2'] * epoch_size)#步幅值
    step_index = 0#步幅指标
    #开始通道
    if args.resume_epoch > 0:
        start_iter = args.resume_epoch * epoch_size
    else:
        start_iter = 0
    #开始迭代
    for iteration in range(start_iter, max_iter):
        if iteration % epoch_size == 0:
            # create batch iterator新一轮epoch加载数据,把全部数据又重新加载了,下面的next(batch_iterator)再逐batch_size地取数据
            batch_iterator = iter(data.DataLoader(dataset, batch_size, shuffle=True, num_workers=num_workers, collate_fn=detection_collate))#生成迭代器
            if (epoch % 10 == 0 and epoch > 0) or (epoch % 5 == 0 and epoch > cfg['decay1']):
                torch.save(net.state_dict(), save_folder + cfg['name']+ '_epoch_' + str(epoch) + '.pth')#存储pth文件
            epoch += 1

        load_t0 = time.time()
        if iteration in stepvalues:
            step_index += 1
        lr = adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size)

        # load train data
        images, targets = next(batch_iterator) # batch_iterator一次性加载了数据,next操作就逐个batch_size地取出数据了
        images = images.cuda()
        targets = [anno.cuda() for anno in targets]

        # forward
        #前向传播
        out = net(images)

        # backprop
        optimizer.zero_grad()#梯度清0
        #计算损失
        loss_l, loss_c, loss_landm = criterion(out, priors, targets)
        #总损失
        loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
        #反向传播
        loss.backward()
        #更新参数
        optimizer.step()
        load_t1 = time.time()
        batch_time = load_t1 - load_t0
        #预测还有多少时间
        eta = int(batch_time * (max_iter - iteration))
        print('Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || Loc: {:.4f} Cla: {:.4f} Landm: {:.4f} || LR: {:.8f} || Batchtime: {:.4f} s || ETA: {}'
              .format(epoch, max_epoch, (iteration % epoch_size) + 1,
              epoch_size, iteration + 1, max_iter, loss_l.item(), loss_c.item(), loss_landm.item(), lr, batch_time, str(datetime.timedelta(seconds=eta))))

    torch.save(net.state_dict(), save_folder + cfg['name'] + '_Final.pth')
    # torch.save(net.state_dict(), save_folder + 'Final_Retinaface.pth')

#重新设置学习率
def adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size):
    """Sets the learning rate
    # Adapted from PyTorch Imagenet example:
    # https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """
    warmup_epoch = -1
    if epoch <= warmup_epoch:
        lr = 1e-6 + (initial_lr-1e-6) * iteration / (epoch_size * warmup_epoch)
    else:
        lr = initial_lr * (gamma ** (step_index))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

if __name__ == '__main__':
    train()

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值