pytorch目标检测ssd七__训练代码与loss组成解析

本篇博客是我学习(https://blog.csdn.net/weixin_44791964)博主写的pytorch的ssd的博客后写的,大家可以直接去看这位博主的博客(https://blog.csdn.net/weixin_44791964/article/details/104981486)。这位博主在b站还有配套视频,传送门:(https://www.bilibili.com/video/BV1A7411976Z)。这位博主的在GitHub的源代码(https://github.com/bubbliiiing/ssd-pytorch)。 侵删

这篇博客主要是理清楚ssd目标检测算法的训练思路

下面就是训练文件的代码了,注释都在代码里面

from nets.ssd import get_ssd
from nets.ssd_training import Generator,MultiBoxLoss
from utils.config import Config
#from torchsummary import summary
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import time
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
def adjust_learning_rate(optimizer, lr, gamma, step):
    lr = lr * (gamma ** (step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

if __name__ == "__main__":
    Batch_size = 4
    lr = 1e-5
    Epoch = 50
    Cuda = False
    Start_iter = 0
    # 需要使用device来指定网络在GPU还是CPU运行
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #获得ssd目标检测算法的模型
    model = get_ssd("train",Config["num_classes"])

    #载入我们与训练好的预训练的模型,类似于迁移学习的思想,但是嗷,这里用的是gpu训练出来的参数,我电脑没有gpu,实验室电脑连不上,真的dmn了嗷
    print('Loading weights into state dict...')
    model_dict = model.state_dict()
    #pretrained_dict = torch.load("model_data/ssd_weights.pth")
    #pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) ==  np.shape(v)}
    #model_dict.update(pretrained_dict)
    #model.load_state_dict(model_dict)
    print('Finished!')

    #设置了模型的cuda参数
    net = model
    if Cuda:
        net = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        net = net.cuda()

    """
    2007_train.txt这个其实是我们执行voc_annotation.py之后生成的文件,
    这个文件里面存放了图片的路径和他所对应的目标
    """
    annotation_path = '2007_train.txt'
    with open(annotation_path) as f:
        lines = f.readlines()
    np.random.seed(10101)
    #打开文件之后进行一个shuffle的打乱
    np.random.shuffle(lines)
    np.random.seed(None)
    num_train = len(lines)

    """
    使用Generator来对我们的图片进行一次预处理,
    Generator会利用2007_train.txt文件去生成图片和对应的标签
    """
    gen = Generator(Batch_size, lines,
                    (Config["min_dim"], Config["min_dim"]), Config["num_classes"]).generate()


    #设置优化器
    optimizer = optim.Adam(net.parameters(), lr=lr)
    #MultiBoxLoss是ssd使用的loss函数
    criterion = MultiBoxLoss(Config['num_classes'], 0.5, True, 0, True, 3, 0.5,
                             False, Cuda)

    net.train()


    epoch_size = num_train // Batch_size
    for epoch in range(Start_iter,Epoch):
        if epoch%10==0:
            adjust_learning_rate(optimizer,lr,0.95,epoch)
        loc_loss = 0
        conf_loss = 0
        #首先取出一个batch来进行训练
        for iteration in range(epoch_size):
            images, targets = next(gen)
            with torch.no_grad():
                if Cuda:
                    #将图片和target变成变量的形式
                    images = Variable(torch.from_numpy(images).cuda().type(torch.FloatTensor))
                    targets = [Variable(torch.from_numpy(ann).cuda().type(torch.FloatTensor)) for ann in targets]
                else:
                    images = Variable(torch.from_numpy(images).type(torch.FloatTensor))
                    targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
            # 前向传播
            out = net(images)
            # 清零梯度
            optimizer.zero_grad()
            # 计算loss
            loss_l, loss_c = criterion(out, targets)
            loss = loss_l + loss_c
            # 反向传播
            loss.backward()
            optimizer.step()
            # 加上
            loc_loss += loss_l.item()
            conf_loss += loss_c.item()

            print('\nEpoch:'+ str(epoch+1) + '/' + str(Epoch))
            print('iter:' + str(iteration) + '/' + str(epoch_size) + ' || Loc_Loss: %.4f || Conf_Loss: %.4f ||' % (loc_loss/(iteration+1),conf_loss/(iteration+1)), end=' ')

        #每一个batch进行一次权重的保存
        print('Saving state, iter:', str(epoch+1))
        torch.save(model.state_dict(), 'logs/Epoch%d-Loc%.4f-Conf%.4f.pth'%((epoch+1),loc_loss/(iteration+1),conf_loss/(iteration+1)))

首先就是读取文件,然后利用Generator获得图片及其对应的标签,然后就是基本的训练了

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值