目标检测(四)训练与测试

目标检测(四)训练与测试

开始

内容参考:Datawhale Task04:不讲武德-炼丹与品尝 终于,神功初成,可以开始施展拳脚了

一· 模型训练

目标检测网络的训练大致是如下的流程:

设置各种超参数
定义数据加载模块
定义网络模型
定义损失函数
定义优化器
遍历训练数据,预测-计算损失-反向传播

首先,引入必要的库,然后设定各种超参数
整体代码为:

import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from model import tiny_detector, MultiBoxLoss
from datasets import PascalVOCDataset
from utils import *
import os


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True

# Data parameters
data_folder = '../../../dataset/VOCdevkit'  # data files root path
keep_difficult = True  # use objects considered difficult to detect?
n_classes = len(label_map)  # number of different types of objects

# Learning parameters
total_epochs = 230 # number of epochs to train
batch_size = 8  # batch size
workers = 4  # number of workers for loading data in the DataLoader
print_freq = 100  # print training status every __ batches
lr = 1e-3  # learning rate
decay_lr_at = [150, 190]  # decay learning rate after these many epochs
decay_lr_to = 0.1  # decay learning rate to this fraction of the existing learning rate
momentum = 0.9  # momentum
weight_decay = 5e-4  # weight decay


def main():
    """
    Training.
    """
    # Initialize model and optimizer
    model = tiny_detector(n_classes=n_classes)
    criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy)
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=lr, momentum=momentum, weight_decay=weight_decay)

    # Move to default device
    model = model.to(device)
    criterion = criterion.to(device)

    # Custom dataloaders
    train_dataset = PascalVOCDataset(data_folder,
                                     split='train',
                                     keep_difficult=keep_difficult)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                               collate_fn=train_dataset.collate_fn, num_workers=workers,
                                               pin_memory=True)  # note that we're passing the collate function here

    start_epoch = 0
    if os.path.exists('checkpoint.pth.tar'):
        checkpoint = torch.load('checkpoint.pth.tar')
        model = checkpoint["model"]
        start_epoch = checkpoint["epoch"] + 1
        optimi
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Transformer 目标检测训练是一种使用 Transformer 模型进行目标检测任务训练的方法。传统的目标检测方法主要是基于卷积神经网络 (Convolutional Neural Networks, CNNs) 进行特征提取和分类,而 Transformer 模型则是一种基于自注意力机制的深度学习模型,用于处理序列数据。 在使用 Transformer 进行目标检测训练时,常见的方法是将输入图像划分为一系列不同尺度的区域,然后将这些区域转换为序列数据。每个区域都会被编码成一个向量,并通过 Transformer 模型进行处理。这样可以捕捉到不同区域之间的关系和上下文信息,从而提高目标检测的准确性。 通常情况下,Transformer 目标检测训练包括以下步骤: 1. 数据准备:收集、标注和预处理训练数据集,包括图像和相应的目标框标注。 2. 特征提取:使用预训练的卷积神经网络 (如 ResNet 或 VGG) 对输入图像进行特征提取。 3. 区域划分:将图像划分为不同尺度的区域,并将每个区域编码为向量表示。 4. 序列转换:使用 Transformer 模型对区域向量序列进行处理,以获取上下文信息和关系。 5. 目标分类和边界框回归:使用分类器对每个区域进行目标分类,并回归出边界框的位置。 6. 损失计算和反向传播:计算预测结果与真实标注之间的损失,并通过反向传播优化模型参数。 7. 模型评估和调优:使用验证集评估模型性能,并进行参数调优和模型选择。 8. 测试与推理:使用训练好的模型对新的图像进行目标检测推理。 需要注意的是,由于 Transformer 模型在处理图像数据时相对较慢,通常需要结合其他技术或优化策略来加速训练和推理过程。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值