开始
内容参考：Datawhale Task04：不讲武德——炼丹与品尝。终于，神功初成，可以开始施展拳脚了。
一、模型训练
目标检测网络的训练大致遵循如下流程：
设置各种超参数
定义数据加载模块
定义网络模型
定义损失函数
定义优化器
遍历训练数据，依次执行：预测 → 计算损失 → 反向传播
首先，引入必要的库，然后设定各种超参数。
整体代码如下：
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from model import tiny_detector, MultiBoxLoss
from datasets import PascalVOCDataset
from utils import *
import os
# Train on the GPU whenever CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Enable cuDNN's autotuner (torch.backends.cudnn.benchmark), which benchmarks
# convolution algorithms and reuses the fastest one for repeated input shapes.
cudnn.benchmark = True

# ---------------- Data parameters ----------------
data_folder = '../../../dataset/VOCdevkit'  # root directory of the VOC data files
keep_difficult = True  # also use objects flagged as difficult to detect
# Number of object classes; label_map is presumably supplied by `from utils import *`.
n_classes = len(label_map)

# ---------------- Learning parameters ----------------
total_epochs = 230  # how many epochs to train for
batch_size = 8  # samples per training batch
workers = 4  # worker processes used by the DataLoader
print_freq = 100  # report training status every this many batches
lr = 0.001  # initial SGD learning rate
decay_lr_at = [150, 190]  # epochs after which the learning rate is decayed
decay_lr_to = 0.1  # at each decay, scale the current learning rate by this fraction
momentum = 0.9  # SGD momentum
weight_decay = 0.0005  # SGD weight decay
def main():
"""
Training.
"""
# Initialize model and optimizer
model = tiny_detector(n_classes=n_classes)
criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy)
optimizer = torch.optim.SGD(params=model.parameters(),
lr=lr, momentum=momentum, weight_decay=weight_decay)
# Move to default device
model = model.to(device)
criterion = criterion.to(device)
# Custom dataloaders
train_dataset = PascalVOCDataset(data_folder,
split='train',
keep_difficult=keep_difficult)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
collate_fn=train_dataset.collate_fn, num_workers=workers,
pin_memory=True) # note that we're passing the collate function here
start_epoch = 0
if os.path.exists('checkpoint.pth.tar'):
checkpoint = torch.load('checkpoint.pth.tar')
model = checkpoint["model"]
start_epoch = checkpoint["epoch"] + 1
optimi