前言
本文实现在Linux服务器上用自己的数据集训练YOLO-NAS,看CSDN上YOLO-NAS的教程基本都需要开专栏,自己花了很长时间才成功训练,所以就总结一下供大家参考
一、环境配置
新建conda环境,本文采用python3.10,不熟悉conda命令的可以参考这篇文章
安装super-gradients库,建议不指定版本号,我指定版本号安装时会报错
pip install super-gradients
二、数据准备
参考这篇文章,但稍有修改,最终的data文件夹如下所示(与train.py中配置的路径对应):data/images/train、data/images/val、data/labels/train、data/labels/val
.yaml文件不用管
三、训练
3.1 编写train.py文件
需要修改第33行的类别名称(config.CLASSES);第45行选择你的模型(config.MODEL_NAME),s、m、l模型依次变大,训练时间会变长,精度更好;第104行修改epoch(train_params中的max_epochs)
import os
import requests
import torch
from PIL import Image
from super_gradients.training import Trainer, dataloaders, models
from super_gradients.training.dataloaders.dataloaders import (
coco_detection_yolo_format_train, coco_detection_yolo_format_val
)
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import (
PPYoloEPostPredictionCallback
)
class config:
    """Central settings for the YOLO-NAS training script.

    Groups the trainer, dataset, dataloader and model options in one place so
    the rest of the script only reads ``config.<NAME>`` attributes.
    """

    # trainer params
    CHECKPOINT_DIR = 'checkpoints'  # specify the path you want to save checkpoints to
    EXPERIMENT_NAME = 'football'  # specify the experiment name

    # dataset params
    DATA_DIR = 'data'  # parent directory to where data lives
    TRAIN_IMAGES_DIR = 'images/train'  # child dir of DATA_DIR where train images are
    TRAIN_LABELS_DIR = 'labels/train'  # child dir of DATA_DIR where train labels are
    VAL_IMAGES_DIR = 'images/val'  # child dir of DATA_DIR where validation images are
    VAL_LABELS_DIR = 'labels/val'  # child dir of DATA_DIR where validation labels are
    # Optional test split; uncomment if you also have a held-out test set.
    # TEST_IMAGES_DIR = 'images/test'  # child dir of DATA_DIR where test images are
    # TEST_LABELS_DIR = 'labels/test'  # child dir of DATA_DIR where test labels are

    # Class names of your dataset.
    # NOTE(review): presumably the order must match the class ids used in the
    # label files -- verify against your annotations.
    CLASSES = ['ball', 'goalkeeper', 'player', 'referee']
    NUM_CLASSES = len(CLASSES)  # number of classes, derived from CLASSES

    # dataloader params - you can add whatever PyTorch dataloader params you have
    # could be different across train, val, and test
    DATALOADER_PARAMS = {
        'batch_size': 16,
        'num_workers': 2,
    }

    # model params
    # Choose one of yolo_nas_s, yolo_nas_m, yolo_nas_l (small, medium, large).
    MODEL_NAME = 'yolo_nas_m'
    PRETRAINED_WEIGHTS = 'coco'  # only one option here: coco
# The trainer owns the experiment; ckpt_root_dir is where checkpoints get saved.
trainer = Trainer(
    ckpt_root_dir=config.CHECKPOINT_DIR,
    experiment_name=config.EXPERIMENT_NAME,
)

# Training split: image and label folders are resolved relative to DATA_DIR.
train_dataset_params = {
    'data_dir': config.DATA_DIR,
    'images_dir': config.TRAIN_IMAGES_DIR,
    'labels_dir': config.TRAIN_LABELS_DIR,
    'classes': config.CLASSES,
}
train_data = coco_detection_yolo_format_train(
    dataset_params=train_dataset_params,
    dataloader_params=config.DATALOADER_PARAMS,
)

# Validation split: same layout, pointing at the val folders.
val_dataset_params = {
    'data_dir': config.DATA_DIR,
    'images_dir': config.VAL_IMAGES_DIR,
    'labels_dir': config.VAL_LABELS_DIR,
    'classes': config.CLASSES,
}
val_data = coco_detection_yolo_format_val(
    dataset_params=val_dataset_params,
    dataloader_params=config.DATALOADER_PARAMS,
)
# Optional held-out test split. To enable it, uncomment this whole call
# together with TEST_IMAGES_DIR / TEST_LABELS_DIR in ``config``.
# test_data = coco_detection_yolo_format_val(
#     dataset_params={
#         'data_dir': config.DATA_DIR,
#         'images_dir': config.TEST_IMAGES_DIR,
#         'labels_dir': config.TEST_LABELS_DIR,
#         'classes': config.CLASSES
#     },
#     dataloader_params=config.DATALOADER_PARAMS
# )

# NOTE(review): presumably visualizes samples from the training dataset --
# uncomment to sanity-check the data before training.
# train_data.dataset.plot()
# Instantiate the architecture named in config, with num_classes set to the
# dataset's class count and the pretrained weights selected in config.
model = models.get(
    config.MODEL_NAME,
    pretrained_weights=config.PRETRAINED_WEIGHTS,
    num_classes=config.NUM_CLASSES,
)
# Training loss; it must be given the number of classes explicitly.
detection_loss = PPYoloELoss(
    use_static_assigner=False,
    num_classes=config.NUM_CLASSES,
    reg_max=16,
)

# Post-processing applied to raw predictions before the metric is computed.
nms_callback = PPYoloEPostPredictionCallback(
    score_threshold=0.01,
    nms_top_k=1000,
    max_predictions=300,
    nms_threshold=0.7,
)

# Validation metric; it must also be given the number of classes explicitly.
map50_metric = DetectionMetrics_050(
    score_thres=0.1,
    top_k_predictions=300,
    num_cls=config.NUM_CLASSES,
    normalize_targets=True,
    post_prediction_callback=nms_callback,
)

train_params = {
    "average_best_models": True,

    # learning-rate warmup and schedule
    "warmup_mode": "linear_epoch_step",
    "warmup_initial_lr": 1e-6,
    "lr_warmup_epochs": 3,
    "initial_lr": 5e-4,
    "lr_mode": "cosine",
    "cosine_final_lr_ratio": 0.1,

    # optimizer
    "optimizer": "Adam",
    "optimizer_params": {"weight_decay": 0.0001},
    "zero_weight_decay_on_bias_and_bn": True,

    # EMA settings
    "ema": True,
    "ema_params": {"decay": 0.9, "decay_type": "threshold"},

    # number of training epochs -- adjust to your dataset and time budget
    "max_epochs": 200,
    "mixed_precision": True,

    "loss": detection_loss,
    "valid_metrics_list": [map50_metric],
    "metric_to_watch": 'mAP@0.50',
}
# Run training with the model, hyper-parameters and dataloaders defined above.
trainer.train(
    model=model,
    training_params=train_params,
    train_loader=train_data,
    valid_loader=val_data,
)

# Reload the averaged checkpoint produced by this run.
# The directory is taken from the trainer instead of being rebuilt from
# CHECKPOINT_DIR/EXPERIMENT_NAME: recent super-gradients releases save each run
# into its own sub-folder, so the hand-built path may not exist.
best_model = models.get(
    config.MODEL_NAME,
    num_classes=config.NUM_CLASSES,
    checkpoint_path=os.path.join(trainer.checkpoints_dir_path, 'average_model.pth'),
)
3.2 开始训练
在终端运行 python train.py 即可开始训练
结语
写的有点仓促,后续再更新,有问题评论区交流