I spent quite a while debugging the standard PyTorch implementation of YOLO, and reorganized its training code, extracting a standard PyTorch training skeleton from it. I am writing it down here for future reuse. The cleaned-up code splits into three major parts, each of which breaks down into three smaller steps:
1. Initialization (Init): before training, create the Model, the Dataset & DataLoader, and the Optimizer;
2. Per-epoch work (Epoch): iterate over the DataLoader to train, save the model (at an interval), and evaluate the model (at an interval);
3. Training (Train): this is really part of the DataLoader iteration inside each epoch, and holds the core training steps: Forward pass, Backward pass, and parameter update (Optimize).
After tidying up and then simplifying the official YOLO PyTorch training code, it looks like the code below.
A few small details are worth noting. Before training, the model must be switched into train mode, and switched into eval mode for evaluation, so that the BatchNorm and Dropout layers behave correctly (in eval mode, BN statistics are frozen and Dropout is disabled). The optimizer must be given the model's parameters to optimize when it is constructed. When moving the input images and labels to the device, the labels must not require gradients. And after each optimizer step, the accumulated gradients must be cleared.
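The last point, clearing gradients, can be seen without torch at all: PyTorch accumulates gradients in-place on each parameter's `.grad` across backward calls, so a toy parameter class (all names here are illustrative, not PyTorch API) shows how a forgotten `zero_grad()` leaks stale gradients into the next update:

```python
# Torch-free sketch of why clearing gradients matters: like torch, the toy
# "parameter" accumulates gradients in-place on every backward pass, so stale
# gradients leak into the next step unless they are explicitly cleared.

class ToyParam:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0          # gradients accumulate here, as in torch

    def backward(self, grad):
        self.grad += grad        # note += , not = : this mirrors torch

    def zero_grad(self):
        self.grad = 0.0

def sgd_step(p, lr=0.1):
    p.value -= lr * p.grad

# Correct loop: clear gradients after each step
p = ToyParam(1.0)
for g in [2.0, 2.0]:
    p.backward(g)
    sgd_step(p)
    p.zero_grad()
correct = p.value                # 1.0 - 0.1*2 - 0.1*2 ≈ 0.6

# Buggy loop: forgetting zero_grad doubles the second update
p = ToyParam(1.0)
for g in [2.0, 2.0]:
    p.backward(g)
    sgd_step(p)                  # second step sees grad = 2 + 2 = 4
buggy = p.value                  # 1.0 - 0.2 - 0.4 ≈ 0.4

print(correct, buggy)
```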
One more practical note: by convention, project configuration parameters are passed in through argparse. To keep the project clear, I put all the argument definitions in the "__main__" section and made the core training loop a standalone function in the file. This arrangement improves readability and makes the code easier to maintain.
import argparse
import os

import torch

# create_model, create_dataset, evaluate, and Logger are helpers defined
# elsewhere in the project (adapted from the official YOLO PyTorch repo).

def Quan_train(opt, logger):
    ### Init Step 1: Create Model
    model, device, start_epoch = create_model(opt)
    ### Init Step 2: Create Dataset & DataLoader
    dataloader, train_path, valid_path, class_names = create_dataset(opt)
    ### Init Step 3: Create Optimizer (it must be given the model's parameters)
    optimizer = torch.optim.Adam(model.parameters())
    # Epoch loop
    for epoch in range(start_epoch, opt.epochs):
        # Set model to train mode (enables Dropout, lets BN update its stats)
        model.train()
        ### Epoch Step 1: Train
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            batches_done = len(dataloader) * epoch + batch_i
            # Move input and target to the device; tensors coming out of a
            # DataLoader already have requires_grad=False, so the targets carry
            # no gradients (Variable wrapping has been deprecated since
            # PyTorch 0.4 and is no longer needed)
            imgs = imgs.to(device)
            targets = targets.to(device)
            ### Train Step 1: Forward pass, get loss
            loss, outputs = model(imgs, targets)
            ### Train Step 2: Backward pass, accumulate gradients
            loss.backward()
            ### Train Step 3: Optimize params
            if (batches_done + 1) % opt.gradient_accumulations == 0:
                # Step once every N batches, then clear the accumulated grads
                optimizer.step()
                optimizer.zero_grad()
        ### Epoch Step 2: Save (at an interval)
        if epoch % opt.checkpoint_interval == 0:
            torch.save(model.state_dict(),
                       f"checkpoints/yolov3-tiny_quan_ckpt_{epoch}.pth")
        ### Epoch Step 3: Eval (at an interval)
        if epoch % opt.evaluation_interval == 0:
            print("\n---- Evaluating Model ----")
            # Switch to eval mode to freeze BN statistics and disable Dropout
            model.eval()
            # Evaluate the model on the validation set
            precision, recall, AP, f1, ap_class, IoU_total = evaluate()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Hyper-parameters
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=64, help="size of each image batch")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
    # ...... other params (--gpu, --checkpoint_interval, --evaluation_interval, etc.)
    opt = parser.parse_args()
    # Set up the logger
    logger = Logger("logs")
    # Restrict training to the chosen GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu)
    # Train
    Quan_train(opt, logger)
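A note on the gradient-accumulation condition. A truthy test like `if batches_done % opt.gradient_accumulations:` (as in the original YOLO repo) fires on every batch except multiples of N, so for N > 2 it steps more often than once per N batches; the conventional schedule is `(batch + 1) % N == 0`. A small torch-free sketch comparing the two for N = 3 (the function `steps` and both lambdas are illustrative names, not project code):

```python
# Compare the optimizer-step schedule produced by a truthy modulo test
# (`batch % n`) with the conventional `(batch + 1) % n == 0` form,
# for an accumulation interval of n = 3 over 6 batches.

def steps(should_step, n, num_batches):
    """Return the batch indices on which the optimizer would step."""
    return [b for b in range(num_batches) if should_step(b, n)]

as_written = steps(lambda b, n: bool(b % n), 3, 6)       # truthy test
fixed      = steps(lambda b, n: (b + 1) % n == 0, 3, 6)  # every n-th batch

print(as_written)  # the truthy form steps on 4 of the 6 batches
print(fixed)       # the conventional form steps on 2 of the 6 batches
```

With N = 2 the two schedules happen to coincide (both step on odd batches), which is probably why the truthy form went unnoticed; with any larger N they diverge.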