代码链接 GitHub - facebookresearch/detr: End-to-End Object Detection with Transformers
论文链接https://arxiv.org/abs/2005.12872
使用GPU复现。关于如何配置GPU,看文章CUDA和CUDNN安装教程
1.环境配置(利用anaconda)
打开Anaconda Prompt 创建并激活环境detr-main,这里要用到的只有创建与激活
conda create -n 环境名 #创建
conda activate 环境名 #激活进入环境
conda deactivate #退出环境
输入安装
conda install -c pytorch pytorch torchvision
conda install cython scipy
pip install git+https://github.com/cocodataset/panopticapi.git
#最后一条没有安装git要先conda install git
2.准备数据
数据集下载网站:http://cocodataset.org.
有时候网站不好用,网站下载速度较慢,可以选择百度网盘下载,一搜就能搜到。
3.下载预训练权重文件,生成pth文件
链接:https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth
新建py文件,mydataset.py,使用如下代码,修改num_classes为自己的类别数+1
import torch
pretrained_weights = torch.load('detr-r50-e632da11.pth')
#NWPU数据集,10类
num_class = 11 #类别数+1,1为背景
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1, 256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)
torch.save(pretrained_weights, "detr-r50_%d.pth"%num_class)
生成pth文件 detr-r50-e632da11.pth
4.指定数据路径
指定pth文件
5.参数修改
修改models/detr.py文件,build()函数中,可以将红框部分的代码都注释掉,直接设置num_classes为自己的类别数+1
6.修改参数,进行训练
main.py 修改参数并运行
参考