Roboflow训练自己样本训练模型-YOLOnas

环境安装

参考链接:https://blog.roboflow.com/yolo-nas-how-to-train-on-custom-dataset/
pip install super-gradients==3.1.1
pip install roboflow
pip install supervision

Load YOLO-NAS Model


```python
import torch
from super_gradients.training import models

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_ARCH = 'yolo_nas_l'
#            'yolo_nas_m'
#            'yolo_nas_s'

model = models.get(MODEL_ARCH, pretrained_weights="coco").to(DEVICE)
CONFIDENCE_TRESHOLD = 0.35

result = list(model.predict(image, conf=CONFIDENCE_TRESHOLD))[0]
import roboflow
from roboflow import Roboflow

roboflow.login()
#  https://app.roboflow.com/auth-cli获取token

rf = Roboflow()
project = rf.workspace(WORKSPACE_ID).project(PROJECT_ID)
dataset = project.version(PROJECT_VERSION).download("yolov5")

在这里插入图片描述
在这里插入图片描述

Select Hyperparameter Values

MODEL_ARCH = 'yolo_nas_l'
BATCH_SIZE = 8
MAX_EPOCHS = 25
CHECKPOINT_DIR = f'{HOME}/checkpoints'
EXPERIMENT_NAME = project.name.lower().replace(" ", "_")
LOCATION = dataset.location
CLASSES = sorted(project.classes.keys())

dataset_params = {
    'data_dir': LOCATION,
    'train_images_dir':'train/images',
    'train_labels_dir':'train/labels',
    'val_images_dir':'valid/images',
    'val_labels_dir':'valid/labels',
    'test_images_dir':'test/images',
    'test_labels_dir':'test/labels',
    'classes': CLASSES
}

from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train, coco_detection_yolo_format_val)

train_data = coco_detection_yolo_format_train(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['train_images_dir'],
        'labels_dir': dataset_params['train_labels_dir'],
        'classes': dataset_params['classes']
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': 2
    }
)

val_data = coco_detection_yolo_format_val(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['val_images_dir'],
        'labels_dir': dataset_params['val_labels_dir'],
        'classes': dataset_params['classes']
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': 2
    }
)

Train a Custom YOLO-NAS Model

trainer.train(
    model=model, 
    training_params=train_params, 
    train_loader=train_data, 
    valid_loader=val_data
)

Evaluating the Custom YOLO-NAS Model

trainer.test(
    model=best_model,
    test_loader=test_data,
    test_metrics_list=DetectionMetrics_050(
        score_thres=0.1, 
        top_k_predictions=300, 
        num_cls=len(dataset_params['classes']), 
        normalize_targets=True, 
        post_prediction_callback=PPYoloEPostPredictionCallback(
            score_threshold=0.01, 
            nms_top_k=1000, 
            max_predictions=300,                                                                              
            nms_threshold=0.7
        )
    )
)

本地模型识别

模型加载,识别返回的结果是ImageDetectionPrediction对象的列表

import torch
from super_gradients.training import models
MODEL_ARCH="yolo_nas_l"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
best_model = models.get(
    MODEL_ARCH,
    #num_classes=len(dataset_params['classes'])
    num_classes=1,
    checkpoint_path=r"E:\***\test\RUN_20240520_092136_000449\average_model.pth"#模型训练的模型地址
).to(DEVICE)
CONFIDENCE_TRESHOLD = 0.35
image_path=r"C:\Users\dell\Pictures\tmp"#里面是测试的数据照片
result = list(best_model.predict(image_path, conf=CONFIDENCE_TRESHOLD))

在这里插入图片描述

可视化训练结果

import cv2
import matplotlib.pyplot as plt
import numpy as np

# 假设img是你的图片数据,predictions是你的预测结果
img =result[0].image
# 创建一个副本
img_copy = np.copy(img)

# 将 BGR 图片转换为 RGB 图片用于显示
# img_copy = cv2.cvtColor(img_copy, cv2.COLOR_BGR2RGB)

# 获取预测结果中的箱子数据
bboxes = result[0].prediction.bboxes_xyxy

for bbox in bboxes:
    # 获取bbox的坐标
    start_point = (int(bbox[0]), int(bbox[1]))
    end_point = (int(bbox[2]), int(bbox[3]))
    # 绘制矩形框到图片上
    img_copy = cv2.rectangle(img_copy, start_point, end_point, (0, 0, 255), 2)

# 利用 matplotlib 来显示图片
plt.imshow(img_copy)
plt.show()
  • 10
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值