Linux服务器训练YOLO-NAS


前言

本文介绍如何在Linux服务器上用自己的数据集训练YOLO-NAS。CSDN上的YOLO-NAS教程基本都需要开通专栏才能看,我自己花了很长时间才训练成功,所以在此总结一下,供大家参考。


一、环境配置

新建conda环境,本文采用python3.10;不熟悉conda命令的可以参考这篇文章
安装super-gradients库,建议不指定版本号,我指定版本号安装时会报错

pip install super-gradients

二、数据准备

参考这篇文章,但稍有修改,最终的data文件夹如下所示:
在这里插入图片描述
.yaml文件不用管


三、训练

3.1 编写train.py文件

需要修改第33行的类别名称(config.CLASSES);第45行选择你的模型(config.MODEL_NAME),s、m、l模型依次变大,训练时间会变长,精度更好;第104行修改epoch(train_params中的max_epochs)

import os

import requests
import torch
from PIL import Image

from super_gradients.training import Trainer, dataloaders, models
from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train, coco_detection_yolo_format_val
)
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import (
    PPYoloEPostPredictionCallback
)

class config:
    # --- trainer settings ---
    CHECKPOINT_DIR = "checkpoints"  # handed to Trainer as ckpt_root_dir
    EXPERIMENT_NAME = "football"  # handed to Trainer as experiment_name

    # --- dataset layout ---
    DATA_DIR = "data"  # parent directory that holds the dataset

    TRAIN_IMAGES_DIR = "images/train"  # training images, relative to DATA_DIR
    TRAIN_LABELS_DIR = "labels/train"  # training labels, relative to DATA_DIR

    VAL_IMAGES_DIR = "images/val"  # validation images, relative to DATA_DIR
    VAL_LABELS_DIR = "labels/val"  # validation labels, relative to DATA_DIR

    # TEST_IMAGES_DIR = "images/test"  # test images, relative to DATA_DIR (disabled)
    # TEST_LABELS_DIR = "labels/test"  # test labels, relative to DATA_DIR (disabled)
    CLASSES = ["ball", "goalkeeper", "player", "referee"]  # class names of your dataset

    NUM_CLASSES = len(CLASSES)  # derived from CLASSES so the two cannot drift apart

    # --- dataloader settings ---
    # any PyTorch dataloader argument may be added; the same dict is used for train and val
    DATALOADER_PARAMS = dict(
        batch_size=16,
        num_workers=2,
    )

    # --- model settings ---
    MODEL_NAME = "yolo_nas_m"  # yolo_nas_s / yolo_nas_m / yolo_nas_l (small / medium / large)
    PRETRAINED_WEIGHTS = "coco"  # only one option here: coco

trainer = Trainer(ckpt_root_dir=config.CHECKPOINT_DIR, experiment_name=config.EXPERIMENT_NAME)

# Training dataloader: images/labels are looked up under config.DATA_DIR.
train_data = coco_detection_yolo_format_train(
    dataset_params=dict(
        data_dir=config.DATA_DIR,
        images_dir=config.TRAIN_IMAGES_DIR,
        labels_dir=config.TRAIN_LABELS_DIR,
        classes=config.CLASSES,
    ),
    dataloader_params=config.DATALOADER_PARAMS,
)

# Validation dataloader: same layout, pointed at the val split.
val_data = coco_detection_yolo_format_val(
    dataset_params=dict(
        data_dir=config.DATA_DIR,
        images_dir=config.VAL_IMAGES_DIR,
        labels_dir=config.VAL_LABELS_DIR,
        classes=config.CLASSES,
    ),
    dataloader_params=config.DATALOADER_PARAMS,
)

# Optional test-split dataloader: enable together with the TEST_*_DIR entries in config.
# test_data = coco_detection_yolo_format_val(
#     dataset_params={
#         'data_dir': config.DATA_DIR,
#         'images_dir': config.TEST_IMAGES_DIR,
#         'labels_dir': config.TEST_LABELS_DIR,
#         'classes': config.CLASSES
#     },
#     dataloader_params=config.DATALOADER_PARAMS
# )
# Optional: uncomment to call the training dataset's plot() helper.
# train_data.dataset.plot()

# Build the network named by config.MODEL_NAME with the configured pretrained weights.
model = models.get(
    config.MODEL_NAME, num_classes=config.NUM_CLASSES, pretrained_weights=config.PRETRAINED_WEIGHTS
)
train_params = dict(
    # Learning-rate warmup/schedule, optimizer and EMA settings.
    average_best_models=True,
    warmup_mode="linear_epoch_step",
    warmup_initial_lr=1e-6,
    lr_warmup_epochs=3,
    initial_lr=5e-4,
    lr_mode="cosine",
    cosine_final_lr_ratio=0.1,
    optimizer="Adam",
    optimizer_params=dict(weight_decay=0.0001),
    zero_weight_decay_on_bias_and_bn=True,
    ema=True,
    ema_params=dict(decay=0.9, decay_type="threshold"),
    # Total number of training epochs; adjust to your dataset and time budget.
    max_epochs=200,
    mixed_precision=True,
    loss=PPYoloELoss(
        use_static_assigner=False,
        # the loss needs the class count of the dataset
        num_classes=config.NUM_CLASSES,
        reg_max=16,
    ),
    valid_metrics_list=[
        DetectionMetrics_050(
            score_thres=0.1,
            top_k_predictions=300,
            # the metric needs the class count of the dataset
            num_cls=config.NUM_CLASSES,
            normalize_targets=True,
            post_prediction_callback=PPYoloEPostPredictionCallback(
                score_threshold=0.01,
                nms_top_k=1000,
                max_predictions=300,
                nms_threshold=0.7,
            ),
        )
    ],
    metric_to_watch="mAP@0.50",
)

# Run training with the model, hyper-parameters and dataloaders defined above in this script.
trainer.train(
    model=model,
    training_params=train_params,
    train_loader=train_data,
    valid_loader=val_data,
)

# Reload the averaged checkpoint after training (presumably written because
# average_best_models is enabled in train_params -- confirm against your super-gradients version).
# NOTE(review): newer super-gradients releases may store checkpoints in a per-run sub-directory of
# <CHECKPOINT_DIR>/<EXPERIMENT_NAME>, so prefer the directory reported by the trainer and fall
# back to the original hard-coded location when that attribute is missing or unset.
checkpoints_dir = getattr(trainer, 'checkpoints_dir_path', None) or os.path.join(
    config.CHECKPOINT_DIR, config.EXPERIMENT_NAME
)
best_model = models.get(
    config.MODEL_NAME,
    num_classes=config.NUM_CLASSES,
    checkpoint_path=os.path.join(checkpoints_dir, 'average_model.pth'),
)


3.2 开始训练

python train.py即可开始训练

结语

写的有点仓促,后续再更新,有问题评论区交流

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值