train.py
from __future__ import division
from models import *
from utils.logger import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *
from test import evaluate
from terminaltables import AsciiTable
import os
import sys
import time
import datetime
import argparse
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
# Main workflow of train.py:
#  1. Parse the command-line arguments (defaults are used for anything omitted).
#  2. Print the parsed arguments.
#  3. Initialise the logger.
#  4. Resolve the train_path, valid_path and class_names files.
#  5. Create the model with randomly initialised weights; pretrained parameters
#     may optionally be loaded on top.
#  6. Load the training images.
#  7. Choose the optimizer.
#  8. Loop over epochs, back-propagating.
#  9. Within each epoch, train batch by batch (batch_i).
# 10. Apply one gradient-descent step every gradient_accumulations batches.
# 11. Record training logs.
# 12. Print evaluation results every evaluation_interval epochs.
# 13. Save a checkpoint every opt.checkpoint_interval epochs.


def str2bool(value):
    """Convert a command-line flag value to a bool.

    Without a converter, argparse stores the raw string, so any non-empty
    value (including "False" or "0") is truthy. This accepts the usual
    spellings of true/false and rejects anything else.

    Args:
        value: The raw argument string, or an already-boolean value.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def should_step_optimizer(batches_done, gradient_accumulations):
    """Return True when the accumulated gradients should be applied.

    Args:
        batches_done: 0-based global batch index (counted across epochs).
        gradient_accumulations: Number of batches whose gradients are summed
            before each optimizer step. Values below 1 are treated as 1
            (step on every batch).

    Returns:
        True after batches N-1, 2N-1, 3N-1, ... for N accumulations, i.e.
        exactly once every N batches.
    """
    accumulations = max(1, int(gradient_accumulations))
    return (batches_done + 1) % accumulations == 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=8, help="size of each image batch")
    parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
    parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--data_config", type=str, default="config/coco.data", help="path to data config file")
    parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
    parser.add_argument("--evaluation_interval", type=int, default=1, help="interval evaluations on validation set")
    # type=str2bool so that e.g. "--multiscale_training False" really disables
    # the option (a raw string "False" would be truthy).
    parser.add_argument("--compute_map", type=str2bool, default=False, help="if True computes mAP every tenth batch")
    parser.add_argument("--multiscale_training", type=str2bool, default=True, help="allow for multi-scale training")
    opt = parser.parse_args()
    print(opt)

    logger = Logger("logs")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # Resolve dataset file paths and class names from the data config.
    data_config = parse_data_config(opt.data_config)
    train_path = data_config["train"]
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])

    # Build the model with randomly initialised weights, then optionally
    # overwrite them from a PyTorch checkpoint (.pth) or a Darknet weight file.
    model = Darknet(opt.model_def).to(device)
    model.apply(weights_init_normal)
    if opt.pretrained_weights:
        if opt.pretrained_weights.endswith(".pth"):
            # map_location lets a checkpoint saved on one device (e.g. GPU) be
            # loaded on another (e.g. a CPU-only machine).
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            model.load_state_dict(torch.load(opt.pretrained_weights, map_location=device))
        else:
            model.load_darknet_weights(opt.pretrained_weights)

    dataset = ListDataset(train_path, augment=True, multiscale=opt.multiscale_training)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    optimizer = torch.optim.Adam(model.parameters())

    # Per-YOLO-layer metrics shown in the console table each batch.
    metrics = [
        "grid_size",
        "loss",
        "x",
        "y",
        "w",
        "h",
        "conf",
        "cls",
        "cls_acc",
        "recall50",
        "recall75",
        "precision",
        "conf_obj",
        "conf_noobj",
    ]
    # Printf-style format per metric; built once instead of on every iteration.
    metric_formats = {m: "%.6f" for m in metrics}
    metric_formats["grid_size"] = "%2d"
    metric_formats["cls_acc"] = "%.2f%%"

    for epoch in range(opt.epochs):
        model.train()
        start_time = time.time()
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            # Global batch index across all epochs (used for gradient
            # accumulation and as the logging step).
            batches_done = len(dataloader) * epoch + batch_i

            # torch.autograd.Variable is a no-op wrapper since PyTorch 0.4;
            # plain tensors are used directly.
            imgs = imgs.to(device)
            targets = targets.to(device)

            # Compute the loss. The loss is actually computed inside the YOLO
            # layers: the YOLO layer is the detection layer that outputs the
            # object's class, coordinates and size, so that is where the loss
            # is computed. This can be seen in the Darknet forward pass (during
            # training, targets is not None). The YOLO layer is implemented in
            # YOLOLayer; see its forward() for how the loss is computed.
            loss, outputs = model(imgs, targets)
            loss.backward()

            # Sum gradients over gradient_accumulations batches, then take a
            # single optimizer step and clear them. (The previous test
            # `batches_done % n` was inverted and stepped on every batch that
            # was *not* a multiple of n.)
            if should_step_optimizer(batches_done, opt.gradient_accumulations):
                optimizer.step()
                optimizer.zero_grad()

            loss_value = loss.item()

            # ---- Console metrics table ----
            log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (epoch, opt.epochs, batch_i, len(dataloader))
            metric_table = [["Metrics", *[f"YOLO Layer {i}" for i in range(len(model.yolo_layers))]]]
            for metric in metrics:
                row_metrics = [metric_formats[metric] % yolo.metrics.get(metric, 0) for yolo in model.yolo_layers]
                metric_table += [[metric, *row_metrics]]

            # ---- Scalar logging (every metric except grid_size, per layer) ----
            tensorboard_log = []
            for j, yolo in enumerate(model.yolo_layers):
                for name, value in yolo.metrics.items():
                    if name != "grid_size":
                        tensorboard_log += [(f"{name}_{j+1}", value)]
            tensorboard_log += [("loss", loss_value)]
            logger.list_of_scalars_summary(tensorboard_log, batches_done)

            log_str += AsciiTable(metric_table).table
            log_str += f"\nTotal loss {loss_value}"

            # Estimate the remaining time in this epoch from the average batch
            # duration so far.
            epoch_batches_left = len(dataloader) - (batch_i + 1)
            time_left = datetime.timedelta(seconds=epoch_batches_left * (time.time() - start_time) / (batch_i + 1))
            log_str += f"\n---- ETA {time_left}"
            print(log_str)

            model.seen += imgs.size(0)

        # A non-positive interval disables evaluation instead of raising
        # ZeroDivisionError.
        if opt.evaluation_interval > 0 and epoch % opt.evaluation_interval == 0:
            print("\n---- Evaluating Model ----")
            precision, recall, AP, f1, ap_class = evaluate(
                model,
                path=valid_path,
                iou_thres=0.5,
                conf_thres=0.5,
                nms_thres=0.5,
                img_size=opt.img_size,
                batch_size=8,
            )
            evaluation_metrics = [
                ("val_precision", precision.mean()),
                ("val_recall", recall.mean()),
                ("val_mAP", AP.mean()),
                ("val_f1", f1.mean()),
            ]
            logger.list_of_scalars_summary(evaluation_metrics, epoch)

            # Per-class AP table.
            ap_table = [["Index", "Class name", "AP"]]
            for i, c in enumerate(ap_class):
                ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
            print(AsciiTable(ap_table).table)
            print(f"---- mAP {AP.mean()}")

        # A non-positive interval disables checkpointing instead of raising.
        if opt.checkpoint_interval > 0 and epoch % opt.checkpoint_interval == 0:
            torch.save(model.state_dict(), f"checkpoints/yolov3_ckpt_{epoch}.pth")