YOLO-NAS: An Object Detection Model That Outperforms YOLOv8


Environment

This assumes you have already installed the NVIDIA GPU driver and Miniconda, so I won't go over those here. Here is my environment:

Ubuntu 22.04
Python 3.10.12
CUDA 11.8

Installing PyTorch

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
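After installing, it is worth confirming that the CUDA 11.8 build was actually picked up. Below is a minimal check, assuming you run it inside the same conda environment; the function name is illustrative and it only uses the standard library plus PyTorch (if present):

```python
import importlib.util
import platform


def check_environment():
    """Report the Python version and, if PyTorch is installed, its CUDA status."""
    info = {
        "python": platform.python_version(),
        "torch": None,             # torch.__version__, if installed
        "torch_cuda_build": None,  # CUDA version the torch wheel was built against
        "cuda_available": False,   # whether torch can see a usable GPU
    }
    if importlib.util.find_spec("torch") is None:
        return info  # PyTorch is not installed in the active environment
    import torch

    info["torch"] = torch.__version__
    info["torch_cuda_build"] = torch.version.cuda  # None for CPU-only builds
    info["cuda_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

With the cu118 wheels from the command above you would expect `torch_cuda_build` to be `11.8` and `cuda_available` to be `True`.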

Here are the dependency versions in my environment (pip freeze output):

absl-py==2.0.0
alabaster==0.7.13
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
arabic-reshaper==3.0.0
asn1crypto==1.5.1
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
astunparse==1.6.3
attrs==23.1.0
Babel==2.13.1
boto3==1.29.6
botocore==1.32.7
build==1.0.3
cachetools==5.3.2
certifi==2022.12.7
cffi==1.16.0
charset-normalizer==2.1.1
click==8.1.7
coloredlogs==15.0.1
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1691044910542/work
contourpy==1.2.0
coverage==5.3.1
cryptography==41.0.5
cssselect2==0.7.0
cycler==0.12.1
data-gradients==0.3.1
debugpy @ file:///croot/debugpy_1690905042057/work
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
Deprecated==1.2.14
docutils==0.17.1
einops==0.3.2
entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1700579780973/work
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
fast-histogram==0.12
filelock==3.9.0
flatbuffers==23.5.26
fonttools==4.44.0
fsspec==2023.4.0
future==0.18.3
gast==0.5.4
google-auth==2.23.4
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.59.2
h5py==3.10.0
html5lib==1.1
humanfriendly==10.0
hydra-core==1.3.2
idna==3.4
imagededup==0.3.2
imagesize==1.4.1
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1698244021190/work
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1701831663892/work
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
json-tricks==3.16.1
jsonschema==4.19.2
jsonschema-specifications==2023.7.1
jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1698673647019/work
keras==2.8.0
Keras-Preprocessing==1.1.2
kiwisolver==1.4.5
libclang==16.0.6
lxml==4.9.3
Mako==1.3.0
Markdown==3.5.1
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.1
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
mdurl==0.1.1
mpmath==1.3.0
nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1697083700168/work
networkx==3.0
numpy==1.23.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
oauthlib==3.2.2
omegaconf==2.3.0
onnx==1.14.1
onnx-graphsurgeon==0.3.27
onnx-simplifier==0.4.35
onnxoptimizer==0.3.8
onnxruntime==1.16.0
onnxruntime-gpu==1.16.3
onnxsim==0.4.35
opencv-python==4.8.1.78
opt-einsum==3.3.0
oscrypto==1.3.0
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1696202382185/work
pandas==2.1.2
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow==9.3.0
pip-tools==7.3.0
platformdirs==3.11.0
prettytable==3.9.0
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work
protobuf==3.20.3
psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycocotools==2.0.6
pycparser==2.21
pycuda==2020.1
pyDeprecate==0.3.2
Pygments==2.16.1
pyHanko==0.20.1
pyhanko-certvalidator==0.24.1
pyparsing==2.4.5
pypdf==3.17.0
pypng==0.20220715.0
pyproject_hooks==1.0.0
python-bidi==0.4.2
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
pytools==2023.1.1
pytorch-quantization==2.1.2
pytz==2023.3.post1
PyWavelets==1.4.1
PyYAML==6.0.1
pyzmq @ file:///croot/pyzmq_1686601365461/work
qrcode==7.4.2
rapidfuzz==3.5.2
referencing==0.30.2
reportlab==3.6.13
requests==2.28.1
requests-oauthlib==1.3.1
rich==13.6.0
rknn-toolkit2 @ file:///root/liaosc/export/rknn_toolkit2-1.6.0%2B81f21f4d-cp310-cp310-linux_x86_64.whl#sha256=559ed24ea17c678ccb7ebb352d0f3f16a4e2fb604b359d94d035c7c720814e8e
rpds-py==0.12.0
rsa==4.8
ruamel.yaml==0.18.5
ruamel.yaml.clib==0.2.8
s3transfer==0.7.0
scikit-learn==1.3.2
scipy==1.11.3
seaborn==0.13.0
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
snowballstemmer==2.2.0
Sphinx==4.0.3
sphinx-glpi-theme==0.4.1
sphinx-rtd-theme==1.3.0
sphinxcontrib-applehelp==1.0.4
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.1
sphinxcontrib-jquery==4.1
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
stringcase==1.2.0
-e git+https://github.com/Deci-AI/super-gradients.git@bbcfc105fbcd0dd5794a9874ff14277aecf8b263#egg=super_gradients
svglib==1.5.1
sympy==1.12
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.8.0
tensorflow-io-gcs-filesystem==0.35.0
tensorrt @ file:///root/liaosc/project/TensorRT-8.5.3.1/python/tensorrt-8.5.3.1-cp310-none-linux_x86_64.whl#sha256=2264a978f8f5b98ec9d98464e3f8a015298ba967e894fc252042fecf4ceeded7
termcolor==1.1.0
tf-estimator-nightly==2.8.0.dev2021122109
threadpoolctl==3.2.0
tinycss2==1.2.1
tomli==2.0.1
torch==1.13.1
torch2trt==0.4.0
torchaudio==2.1.1+cu118
torchmetrics==0.8.0
torchvision==0.16.1+cu118
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827254365/work
tqdm==4.66.1
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1701095650114/work
treelib==1.6.1
triton==2.1.0
typing_extensions==4.4.0
tzdata==2023.3
tzlocal==5.2
uritools==4.0.2
urllib3==1.26.13
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1700607916581/work
webencodings==0.5.1
Werkzeug==3.0.1
wrapt==1.15.0
xhtml2pdf==0.2.11
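The list above probably should not be reproduced verbatim: the `@ file://` entries point to local build paths that only exist on the original machine, and it pins torch==1.13.1 next to torchvision==0.16.1+cu118 and torchaudio==2.1.1+cu118, which are built for torch 2.1.x (torchvision 0.14.x is the release paired with torch 1.13.x). A small sketch for parsing a freeze file and flagging that kind of torch/torchvision mismatch; the helper names are hypothetical and the compatibility map is only a small excerpt of the official torchvision table:

```python
def parse_freeze(text):
    """Parse `pip freeze` output into {lowercased package name: version or source}."""
    pins = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if line.startswith("-e "):
            # editable install, e.g. "-e git+https://...#egg=super_gradients"
            source = line[3:].strip()
            name = source.rsplit("#egg=", 1)[1] if "#egg=" in source else source
            pins[name.lower()] = source
        elif " @ " in line:
            # direct reference, e.g. "tensorrt @ file:///.../tensorrt-8.5.3.1-...whl"
            name, source = line.split(" @ ", 1)
            pins[name.strip().lower()] = source.strip()
        elif "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip().lower()] = version.strip()
    return pins


# torchvision release -> torch release it is built for (excerpt only)
TORCH_FOR_TORCHVISION = {"0.14": "1.13", "0.15": "2.0", "0.16": "2.1"}


def minor(version):
    """Strip the local tag and patch number: '0.16.1+cu118' -> '0.16'."""
    return ".".join(version.split("+")[0].split(".")[:2])


def torch_torchvision_mismatch(pins):
    """True if the pinned torchvision is known to target a different torch release."""
    expected = TORCH_FOR_TORCHVISION.get(minor(pins.get("torchvision", "")))
    return expected is not None and expected != minor(pins.get("torch", ""))
```

Fed the list above, `torch_torchvision_mismatch(parse_freeze(text))` returns `True`.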


1. What Is YOLO-NAS

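A next-generation foundation model for object detection, generated by Deci's neural architecture search technology.
Deci has announced the release of YOLO-NAS, a new object detection model that it calls a game-changer for the field, offering strong real-time detection and production-ready performance. Deci's stated mission is to give AI teams tools that remove development barriers and reach efficient inference performance faster.
[Figure: YOLO-NAS performance comparison]
According to Deci, YOLO-NAS delivers state-of-the-art (SOTA) performance and an accuracy/speed trade-off that beats other models such as YOLOv5, YOLOv6, YOLOv7, and YOLOv8.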

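The YOLO-NAS models were generated by AutoNAC™, Deci's proprietary neural architecture search technology. The AutoNAC™ engine takes a task, data characteristics (without needing access to the data itself), an inference environment, and performance targets as input, and then searches for the architecture that gives the best balance between accuracy and inference speed for your specific application. Besides being data- and hardware-aware, AutoNAC also takes other components of the inference stack into account, including compilers and quantization.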
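AutoNAC itself is proprietary and its internals are not described here. The toy sketch below only illustrates the final selection step the paragraph describes: given candidate architectures with a measured accuracy and latency on the target hardware, keep the ones on the accuracy/latency Pareto front and pick the most accurate one that meets a latency target. The candidate names and numbers are hypothetical:

```python
from typing import List, NamedTuple, Optional, Sequence


class Candidate(NamedTuple):
    name: str
    accuracy: float    # e.g. validation mAP
    latency_ms: float  # measured on the target hardware and inference stack


def pick_best(candidates: Sequence[Candidate], latency_budget_ms: float) -> Optional[Candidate]:
    """Most accurate candidate whose latency fits the budget, or None if none fits."""
    feasible = [c for c in candidates if c.latency_ms <= latency_budget_ms]
    return max(feasible, key=lambda c: c.accuracy, default=None)


def pareto_front(candidates: Sequence[Candidate]) -> List[Candidate]:
    """Candidates that no other candidate beats on both accuracy and latency."""
    def dominates(a: Candidate, b: Candidate) -> bool:
        no_worse = a.accuracy >= b.accuracy and a.latency_ms <= b.latency_ms
        better = a.accuracy > b.accuracy or a.latency_ms < b.latency_ms
        return no_worse and better

    return [c for c in candidates if not any(dominates(o, c) for o in candidates)]


# hypothetical numbers, purely for illustration
pool = [Candidate("arch_a", 0.45, 5.0), Candidate("arch_b", 0.47, 8.0), Candidate("arch_c", 0.44, 9.0)]
print(pick_best(pool, latency_budget_ms=6.0).name)  # arch_a
print([c.name for c in pareto_front(pool)])         # ['arch_a', 'arch_b']
```

A real search engine also has to generate and evaluate the candidates, which is the expensive part; this only shows how the "best trade-off under a performance target" is chosen once the measurements exist.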

In raw numbers, Deci reports that YOLO-NAS is about 0.5 mAP points more accurate and 10-20% faster than the equivalent variants of YOLOv8 and YOLOv7.

2. YOLO-NAS Quick Start

import super_gradients

yolo_nas = super_gradients.training.models.get("yolo_nas_s", pretrained_weights="coco").cuda()
# predict() returns the prediction object; show() only displays it and returns None, so call it separately
model_predictions = yolo_nas.predict("https://deci-pretrained-models.s3.amazonaws.com/sample_images/beatles-abbeyroad.jpg")
model_predictions.show()

# Depending on the super-gradients version, one input image gives either a single ImageDetectionPrediction
# or a one-element collection of them; take the first (and only) image either way
image_prediction = model_predictions if hasattr(model_predictions, "prediction") else model_predictions[0]
prediction = image_prediction.prediction

bboxes = prediction.bboxes_xyxy                     # [[Xmin,Ymin,Xmax,Ymax],..] one box per detected object
class_names = image_prediction.class_names          # ['Class1', 'Class2', ...] list of all class names
class_name_indexes = prediction.labels.astype(int)  # [2, 3, 1, 1, 2, ....] index into class_names for each box
confidences = prediction.confidence.astype(float)   # [0.3, 0.1, 0.9, ...] confidence score for each box
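The outputs above are plain NumPy arrays plus a list of class names, so post-processing is ordinary Python. A minimal sketch that pairs each box with its class name and drops low-confidence detections; the function name and the 0.5 default threshold are illustrative choices, not part of super-gradients:

```python
from typing import Dict, List, Sequence


def summarize_detections(
    bboxes: Sequence[Sequence[float]],
    labels: Sequence[int],
    confidences: Sequence[float],
    class_names: Sequence[str],
    conf_threshold: float = 0.5,
) -> List[Dict]:
    """Pair each box with its class name, drop low-confidence boxes, sort by confidence."""
    results = []
    for box, label, conf in zip(bboxes, labels, confidences):
        if conf < conf_threshold:
            continue  # below the threshold: ignore this detection
        x_min, y_min, x_max, y_max = (float(v) for v in box)
        results.append(
            {
                "class": class_names[int(label)],
                "confidence": float(conf),
                "box_xyxy": (x_min, y_min, x_max, y_max),
            }
        )
    # highest-confidence detections first
    results.sort(key=lambda d: d["confidence"], reverse=True)
    return results
```

With the variables from the quick start this would be called as `summarize_detections(bboxes, class_name_indexes, confidences, class_names)`.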

The result looks like this:

[Figure: detection result on the sample image]

3. Training YOLO-NAS on Your Own Dataset

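First, prepare a training set and a validation set. How you build them is up to you (there are plenty of tutorials online), but the annotations must be in YOLO format.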
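In YOLO (Darknet) format, each image has a .txt label file with the same file name, with one object per line written as `class_id x_center y_center width height`, where the coordinates are normalized by the image width and height to [0, 1] and `class_id` is a 0-based index into the class list (in the script below the classes are simply "0" to "4"). A small sketch of the conversion from pixel corner boxes, e.g. VOC-style annotations; the helper names are hypothetical:

```python
def xyxy_to_yolo_line(class_id, box, img_w, img_h):
    """Pixel (x_min, y_min, x_max, y_max) box -> 'class cx cy w h' normalized to [0, 1]."""
    x_min, y_min, x_max, y_max = box
    cx = (x_min + x_max) / 2 / img_w
    cy = (y_min + y_max) / 2 / img_h
    w = (x_max - x_min) / img_w
    h = (y_max - y_min) / img_h
    return f"{class_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}"


def yolo_line_to_xyxy(line, img_w, img_h):
    """Inverse of xyxy_to_yolo_line: returns (class_id, (x_min, y_min, x_max, y_max)) in pixels."""
    class_id, cx, cy, w, h = line.split()
    cx, w = float(cx) * img_w, float(w) * img_w
    cy, h = float(cy) * img_h, float(h) * img_h
    return int(class_id), (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


print(xyxy_to_yolo_line(0, (100, 200, 300, 400), 640, 640))
# 0 0.312500 0.468750 0.312500 0.312500
```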
Here is the training code:

from super_gradients.training import Trainer, dataloaders, models, training_hyperparams
from super_gradients.training.datasets.data_formats.default_formats import LABEL_CXCYWH
from super_gradients.training.losses.ppyolo_loss import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import (
    PPYoloEPostPredictionCallback,
)
from super_gradients.training.transforms.transforms import (
    DetectionMosaic,
    DetectionRandomAffine,
    DetectionMixup,
    DetectionHorizontalFlip,
    DetectionRGB2BGR,  # only needed if you re-enable the commented-out RGB2BGR transforms
    DetectionHSV,
    DetectionPaddedRescale,
    DetectionStandardize,
    DetectionTargetsFormatTransform,
    DetectionPadToSize,
    DetectionImagePermute,
)


input_dim = (640, 640)
epochs = 200

CHECKPOINT_DIR = "/root/project/notebook_ckpts/"  # where checkpoints, TensorBoard files and logs are saved
trainer = Trainer(experiment_name="disc_train_yolo_nas_s_new", ckpt_root_dir=CHECKPOINT_DIR)

num_classes = 5
classes = [str(i) for i in range(num_classes)]
base_data_dir = "/root/project/data/disc"  # dataset root directory

train_transforms = [
    # training-time augmentations
    DetectionMosaic(input_dim=input_dim, prob=0.2),
    DetectionRandomAffine(
        degrees=10.0,
        translate=0.1,
        scales=[0.1, 2],
        shear=2.0,
        target_size=input_dim,
        filter_box_candidates=True,
        wh_thr=2,
        area_thr=0.1,
        ar_thr=20,
    ),
    # DetectionRGB2BGR(prob=0.5),
    DetectionMixup(
        input_dim=input_dim, mixup_scale=[0.5, 1.5], prob=1.0, flip_prob=0.5
    ),
    DetectionHSV(prob=1.0, hgain=5, sgain=30, vgain=30),
    DetectionHorizontalFlip(prob=0.5),
    DetectionPaddedRescale(input_dim=input_dim),
    DetectionStandardize(max_value=255.0),
    DetectionTargetsFormatTransform(input_dim=input_dim, output_format=LABEL_CXCYWH),
]

val_transforms = [
    # DetectionRGB2BGR(prob=0.5),
    # pads (does not resize) to input_dim, so validation images should be no larger than 640x640
    DetectionPadToSize(output_size=input_dim, pad_value=114),
    DetectionStandardize(max_value=255.0),
    DetectionImagePermute(),
    DetectionTargetsFormatTransform(input_dim=input_dim, output_format=LABEL_CXCYWH),
]

train_dataloader = dataloaders.get(
    name="coco_detection_yolo_format_train",
    dataset_params={
        "data_dir": f"{base_data_dir}/train",
        "classes": classes,
        "images_dir": f"{base_data_dir}/train/images",
        "labels_dir": f"{base_data_dir}/train/labels",
        "transforms": train_transforms,
    },
    dataloader_params={"num_workers": 2, "batch_size": 8, "drop_last": True},
)

val_dataloader = dataloaders.get(
    name="coco_detection_yolo_format_val",
    dataset_params={
        "data_dir": f"{base_data_dir}/validation/",
        "classes": classes,
        "images_dir": f"{base_data_dir}/validation/images",
        "labels_dir": f"{base_data_dir}/validation/labels",
        "transforms": val_transforms,
    },
    dataloader_params={"num_workers": 2, "batch_size": 8, "drop_last": False},  # evaluate every validation image
)

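# Note: this loads the M variant, while the experiment name and the recipe below refer to S;
# switch them if you want them to match
model = models.get("yolo_nas_m", pretrained_weights="coco", num_classes=num_classes)

# start from the COCO YOLO-NAS-S recipe and override what differs for this dataset
train_params = training_hyperparams.get("coco2017_yolo_nas_s")
train_params["max_epochs"] = epochs
train_params["initial_lr"] = 0.0005
train_params["loss"] = PPYoloELoss(use_static_assigner=False, num_classes=num_classes, reg_max=16)
# the recipe's validation metric is configured for COCO's 80 classes; rebuild it for our classes
train_params["valid_metrics_list"] = [
    DetectionMetrics_050(
        score_thres=0.1,
        top_k_predictions=300,
        num_cls=num_classes,
        normalize_targets=True,
        post_prediction_callback=PPYoloEPostPredictionCallback(
            score_threshold=0.01, nms_top_k=1000, max_predictions=300, nms_threshold=0.7
        ),
    )
]
train_params["metric_to_watch"] = "mAP@0.50"

trainer.train(
    model=model,
    training_params=train_params,
    train_loader=train_dataloader,
    valid_loader=val_dataloader,
)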
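With `batch_size=8` and `drop_last=True`, each epoch runs floor(N / 8) iterations and skips the N mod 8 leftover images. That is harmless for a shuffled training set, but on the validation set it means the same images are never evaluated. A quick way to work out the numbers for your own dataset; N=1003 in the example is a made-up value:

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return num_samples // batch_size  # the incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)


def dropped_samples(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Samples skipped every epoch because of drop_last."""
    return num_samples % batch_size if drop_last else 0


# hypothetical dataset of 1003 training images, trained for 200 epochs
print(steps_per_epoch(1003, 8, drop_last=True))        # 125
print(dropped_samples(1003, 8, drop_last=True))        # 3
print(steps_per_epoch(1003, 8, drop_last=True) * 200)  # 25000 iterations in total
```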

Make sure the dataset paths match your own setup! The dataset directory looks roughly like this:

[Figure: dataset folder structure]
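Judging by the paths in the training script, each split (`train`, `validation`) should contain an `images/` folder and a `labels/` folder with one .txt label file per image. A small sanity check along those lines, assuming an image and its label are paired by file name stem; the function name and the list of image suffixes are assumptions:

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def check_split(split_dir):
    """Pair split_dir/images/* with split_dir/labels/*.txt by file stem and report mismatches."""
    split_dir = Path(split_dir)
    images_dir, labels_dir = split_dir / "images", split_dir / "labels"
    for d in (images_dir, labels_dir):
        if not d.is_dir():
            raise FileNotFoundError(f"expected directory {d}")
    images = {p.stem for p in images_dir.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES}
    labels = {p.stem for p in labels_dir.glob("*.txt")}
    return {
        "num_images": len(images),
        "images_without_labels": sorted(images - labels),  # check whether these are intended
        "labels_without_images": sorted(labels - images),  # usually a typo or a deleted image
    }


if __name__ == "__main__":
    base_data_dir = "/root/project/data/disc"  # same dataset root as in the training script
    for split in ("train", "validation"):
        print(split, check_split(f"{base_data_dir}/{split}"))
```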

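If you found this useful, please leave a like before you go! If it gets enough likes, I'll follow up with code for inference, quantization, and acceleration. Haha~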
