RT-DETR rtdetr-r18 - ultralytics - YOLOv8版本训练自己的数据集

本文介绍了如何在Python环境中安装和配置RT-DETR模型,包括YOLOv8环境设置、数据集准备、以及如何使用私有GitHub仓库进行训练。同时,也提及了如何升级PyTorch到1.9.0版本以支持模型运行。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

0 源码与相关参考

源码请从b站up主【魔傀面具】获取:https://space.bilibili.com/286900343?spm_id_from=333.337.0.0

以下代码为私有

github:

https://github.com/Whiffe/ultralytics-RT-DETR

码云

https://github.com/Whiffe/ultralytics-RT-DETR

参考:b站魔傀面具
https://space.bilibili.com/286900343?spm_id_from=333.337.0.0

1 安装

注意,开始之前,需要配置好YOLOv8的环境

conda create -n rtdetr python=3.8
conda activate rtdetr
git clone https://github.com/Whiffe/ultralytics-RT-DETR
# git clone https://github.com/Whiffe/ultralytics-RT-DETR

cd ultralytics-RT-DETR

pip install ultralytics

python setup.py develop

pip install timm thop efficientnet_pytorch einops grad-cam dill -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install -U openmim

mim install mmengine

mim install "mmcv>=2.0.0"

pip install psutil

2 数据集准备

moon.yaml

path: D:\mySearch\RT-detr\ultralytics-RT-DETR\dataset
train: ./yolo_behavior_Dataset_all2/images/train
val: ./yolo_behavior_Dataset_all2/images/val

# number of classes
nc: 5

# class names
names: [ 'weigh','height measure','drop ball','size measure','record']

3 训练 rtdetr-r18

python train.py

在这里插入图片描述

import warnings
warnings.filterwarnings('ignore')
from ultralytics import RTDETR

if __name__ == '__main__':
    model = RTDETR('ultralytics/cfg/models/rt-detr/rtdetr-r18.yaml')
    # model.load('') # loading pretrain weights
    model.train(data='dataset/moon.yaml',
                cache=False,
                imgsz=640,
                epochs=100,
                batch=4,
                workers=4,
                device='0',
                # resume='', # last.pt path
                project='runs/train',
                name='exp',
                )

在这里插入图片描述

4 pytorch 版本升级

pytorch离线安装网站:

https://download.pytorch.org/whl/torch_stable.html

pytorch需要1.9.0及以上版本

cu111/torch-1.9.0%2Bcu111-cp38-cp38-linux_x86_64.whl
https://download.pytorch.org/whl/cu111/torch-1.9.0%2Bcu111-cp38-cp38-linux_x86_64.whl

在这里插入图片描述
卸载

pip uninstall torch

安装

pip install torch-1.9.0+cu111-cp38-cp38-linux_x86_64.whl
03-08
### RTDETR 技术概述 RTDETR (Real-Time Detection Transformer) 是一种专为实时目标检测设计的高效架构。该方法融合了Transformer的强大建模能力与卷积神经网络(CNN)的优势,在保持高精度的同时实现了极高的推理速度[^1]。 #### 架构特点 - **轻量化骨干网**:采用优化后的EfficientNet作为特征提取器,通过减少计算量和参数数量来提高运行效率。 - **多尺度特征聚合机制**:引入FPN(Feature Pyramid Network),使得不同层次的信息能够更好地融合在一起,增强了模型对于大小物体的识别效果。 - **解耦头结构**:将分类和回归任务分离处理,降低了训练难度并提升了最终性能表现。 ```python from rt detr import build_model model = build_model( backbone='efficientnet_b0', num_classes=80, pretrained=True ) ``` #### 安装依赖库 为了顺利部署RTDETR项目,建议先安装所需的Python包: ```bash pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple ``` 此命令会从清华大学镜像源下载必要的软件包,加快安装过程[^2]。 #### 应用实例展示 假设已经完成数据预处理工作,则可以直接调用`infer_image()`函数来进行单张图片的目标检测预测,并利用SAHI框架中的`visualize_object_predictions()`工具直观呈现结果。 ```python import cv2 from sahi.utils.file import save_json from sahi.slicing import slice_image from sahi.predict import get_sliced_prediction, visualize_object_predictions image_path = 'path/to/your/image.jpg' result = model.infer_image(image_path) visualization_result = visualize_object_predictions( image=cv2.imread(image_path), object_prediction_list=result["object_prediction_list"], output_dir="output/", file_name="prediction" ) ``` 上述代码片段展示了如何加载待测图像文件并通过已训练好的RTDETR模型获取其上的对象位置信息;随后借助于SAHI提供的绘图接口保存带有标注框的结果图至指定目录下[^3]。
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CSPhD-winston-杨帆

给我饭钱

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值