centernet下训练自己的数据

目录

一.数据准备

1.制作COCO数据集

2.计算数据集的均值方差

二.代码修改

1.新建类别

 2.加入dataset

 3.修改/src/lib/opts.py

4.修改src/lib/utils/debugger.py文件

二 训练与测试:

1训练:

2测试:

3绘制loss曲线


参照博客:

https://blog.csdn.net/weixin_42634342/article/details/97756458#commentsedit

https://blog.csdn.net/weixin_42634342/article/details/97756458#commentsedit

一.数据准备

1.制作COCO数据集

这里我用的是VOC数据集转COCO

参照博客:

https://blog.csdn.net/weixin_41765699/article/details/100124689

主要trian,val,test三个文件夹下txt转化为json

2.计算数据集的均值方差

import cv2, os, argparse
import numpy as np
from tqdm import tqdm


def main():
    dirs = '/home/zbb/CenterNet/data/plane/images'   # 修改你自己的图片路径
    img_file_names = os.listdir(dirs)
    m_list, s_list = [], []
    for img_filename in tqdm(img_file_names):
        img = cv2.imread(dirs + '/' + img_filename)
        img = img / 255.0
        m, s = cv2.meanStdDev(img)
        m_list.append(m.reshape((3,)))
        s_list.append(s.reshape((3,)))
    m_array = np.array(m_list)
    s_array = np.array(s_list)
    m = m_array.mean(axis=0, keepdims=True)
    s = s_array.mean(axis=0, keepdims=True)
    print("mean = ", m[0][::-1])
    print("std = ", s[0][::-1])

if __name__ == '__main__':
    main()

二.代码修改

1.新建类别

src/lib/datasets/dataset里面新建一个“plane. py”,文件内容照着文件夹下coco.py改成自己的

1).把COCO关键字改为Plane

2)路径格式

使用相对路径报错,改成了绝对路径

3)训练修改

修改为val,train,测试再修改回来

类别名字和类别id改成自己

 2.加入dataset

将数据集加入src/lib/datasets/dataset_factory里面

一定要记得import,否则会报你的类别未定义

 3.修改/src/lib/opts.py

将自己的数据集设为默认数据集,加入到help里面

 修改ctdet任务使用的默认数据集为新添加的数据集,如下(修改分辨率,类别数,均值,方差,数据集名字):

4.修改src/lib/utils/debugger.py文件

变成自己数据的类别和名字,前后数据集名字一定保持一致

再加上自己数据的类别,不包括背景__background__ 

二 训练与测试:

1训练:

 输入命令:

python main.py ctdet --exp_id coco_dla --batch_size 4 --master_batch 1 --lr 1.25e-4  --gpus 0,1

如果显示显存不够之类的那种错误,需要在opts.py文件中将--num_workers改成0,batch_size小

2测试:

  建立的plane.py中修改如下部分,加入if split == ‘test’:…,作用是当test时指定标签文件为之前建立的测试文件     

   运行test.py

       python test.py --exp_id coco_dla --not_prefetch_test ctdet --load_model /home/zbb/CenterNet/exp/ctdet/coco_dla/model_best.pth

结果:

其中,一般使用的是第二行,也就是IOU=0.5,全区域的AP值,其他的分别是不同IOU以及不同目标尺寸区域的结果。 

3绘制loss曲线

训练生成的日志文件一般在exp/ctdet/../../logs.txt

参照博主但是,val—loss绘制不好,先绘制total—loss

import matplotlib.pyplot as plt
import numpy as np


def plot_loss_curve(log_file):
    loss_data = open(log_file)
    all_lines = loss_data.readlines()
    print(all_lines[4].split(' '))
    # losses
    total_loss = []  # 4
    hm_loss = []  # 7
    wh_loss = []  # 10
    off_loss = []  # 13
    val_loss = []  # 19
    spend_time = []  # 16
    num_lines = len(all_lines)
    for line in range(num_lines):
        total_loss1 = all_lines[line].split(' ')[4]
        hm_loss1 = all_lines[line].split(' ')[7]
        wh_loss1 = all_lines[line].split(' ')[10]
        off_loss1 = all_lines[line].split(' ')[13]
        #val_loss1 = all_lines[line].split(' ')[19]
        spend_time1 = all_lines[line].split(' ')[16]
        print(total_loss1)
        print(spend_time1)

        total_loss.append(float(total_loss1))
        #val_loss.append(float(val_loss1))
        hm_loss.append(float(hm_loss1))
        wh_loss.append(float(wh_loss1))
        off_loss.append(float(off_loss1))
        spend_time.append(float(spend_time1))
    return total_loss

if __name__ == '__main__':
    # 标准图形绘制
    # sns.set()
    loss_res18 = plot_loss_curve(
        '/home/zbb/CenterNet/exp/ctdet/coco_dla/logs_2019-10-17-15-41/log.txt')  # 读取训练时生成的日志文件
    fig = plt.figure(figsize=(10, 4))
    ax = fig.add_subplot(111)
    ax.plot(range(len(loss_res18)), loss_res18, 'c', label='building', linewidth=1)  # 这个label是图线自己的标签;

    # ax.set_xlim([0, 800])                                      # 设置刻度;
    # ax.set_xticks(range(0, 500, 100))                          # 设置显示的刻度;
    # ax.set_yticklabels(['jan', 'feb', 'mar'])                  # 设置刻度标签;
    ax.set_xlabel('epochs')  # 设置坐标轴标签;
    ax.set_ylabel('loss_value')
    ax.text(8750, 20, "plane", color='red')  # 加入文本
    ax.set_title('loss_of_CenterNet')
    ax.legend(loc='best')  # 将图例摆放在不遮挡图线的位置即可
    ax.grid()  # 添加网格
    plt.savefig('/home/zbb/CenterNet/loss_of_CenterNet.png')  # 保存文件到指定文件夹
    plt.show()

total——loss结果图:

要使用CenterNet训练自己的数据集,你需要进行以下步骤: 1. 删除之前训练过程中生成的缓存文件。如果你之前使用了coco数据集测试了模型,需要删除CenterNet-master/cache/coco_minival2014.pkl文件。这是因为在第一次运行时,代码会将coco数据集的instances转换为模型所需的格式,并在下一次使用时直接读取。如果你没有训练过coco数据集,可以忽略这一步。\[1\] 2. 修改参数。根据你要训练的模型选择对应的文件,比如models/CenterNet-52.py或models/CenterNet-104.py。在文件中找到第132行,将out_dim的值从80修改为你自己数据集的类别数目。\[2\] 3. 将数据集分成训练集和验证集。将图片文件夹重命名为trainval2014和minival2014,并放置在CenterNet-master/data/coco/images目录下。将对应的json文件命名为instances_trainval2014.json和instances_minival2014.json,并放置在CenterNet-master/data/coco/annotations目录下。\[3\] 完成以上步骤后,你就可以使用CenterNet训练自己的数据集了。 #### 引用[.reference_title] - *1* *2* *3* [CenterNet 训练自己的数据集](https://blog.csdn.net/surserrr/article/details/100153886)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值