YOLO-zht训练-未完待续

目录:1.修改训练代码 2.制作自己数据集 3.采用labelimg打标签 4.数据处理

1.修改训练代码

import os
import math
import time
import argparse
import numpy as np
from tqdm import tqdm
from numpy.testing._private.utils import print_assert_equal

import torch
from torch import optim
from torch.utils.data import dataset
from numpy.core.fromnumeric import shape

from torchsummary import summary

import utils.loss
import utils.utils
import utils.datasets
import model.detector

if __name__ == '__main__':
    # Parse the path of the training configuration file (*.data).
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='./data/coco.data',
                        help='Specify training profile *.data')
    opt = parser.parse_args()
    cfg = utils.utils.load_datafile(opt.data)

    print("训练配置:")
    print(cfg)

    # Datasets: image augmentation is enabled only for the training split.
    train_dataset = utils.datasets.TensorDataset(cfg["train"], cfg["width"], cfg["height"], imgaug = True)
    val_dataset = utils.datasets.TensorDataset(cfg["val"], cfg["width"], cfg["height"], imgaug = False)

    # Effective per-step batch size; gradients are accumulated over
    # cfg["subdivisions"] steps to emulate the full cfg["batch_size"].
    batch_size = int(cfg["batch_size"] / cfg["subdivisions"])
    # Dataloader worker count, capped by CPU count, batch size and 8.
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print(nw)
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   collate_fn=utils.datasets.collate_fn,
                                                   num_workers=nw,
                                                   pin_memory=True,
                                                   drop_last=True,
                                                   persistent_workers=True
                                                   )
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 collate_fn=utils.datasets.collate_fn,
                                                 num_workers=nw,
                                                 pin_memory=True,
                                                 drop_last=False,
                                                 persistent_workers=True
                                                 )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load pretrained weights only when a valid checkpoint path is given.
    load_param = False
    premodel_path = cfg["pre_weights"]
    if premodel_path is not None and os.path.exists(premodel_path):
        load_param = True

    # Name the network `net` (not `model`) so it does not shadow the
    # imported `model` package.
    net = model.detector.Detector(cfg["classes"], cfg["anchor_num"], load_param).to(device)
    summary(net, input_size=(3, cfg["height"], cfg["width"]))

    # Load pretrained parameters, keeping only tensors that exist in the
    # current model and whose shapes match.
    if load_param:
        model_dict = net.state_dict()
        pretrained_dict = torch.load(premodel_path, map_location=device)
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)
        print("Load finetune model param: %s" % premodel_path)
    else:
        print("Initialize weights: model/backbone/backbone.pth")

    optimizer = optim.SGD(params=net.parameters(),
                          lr=cfg["learning_rate"],
                          momentum=0.949,
                          weight_decay=0.0005,
                          )

    # Step-decay schedule: lr *= 0.1 at each milestone epoch in cfg["steps"].
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=cfg["steps"],
                                               gamma=0.1)

    print('Starting training for %g epochs...' % cfg["epochs"])

    # Warmup lasts the first 5 epochs' worth of batches; hoisted out of the
    # inner loops since it is loop-invariant.
    warmup_num = 5 * len(train_dataloader)
    batch_num = 0
    for epoch in range(cfg["epochs"]):
        net.train()
        pbar = tqdm(train_dataloader)

        for imgs, targets in pbar:
            imgs = imgs.to(device).float() / 255.0
            targets = targets.to(device)
            preds = net(imgs)
            iou_loss, obj_loss, cls_loss, total_loss = utils.loss.compute_loss(preds, targets, cfg, device)

            total_loss.backward()

            # Polynomial lr warmup: scale base lr by (progress ** 4) until
            # warmup_num batches have been seen.
            for g in optimizer.param_groups:
                if batch_num <= warmup_num:
                    scale = math.pow(batch_num / warmup_num, 4)
                    g['lr'] = cfg["learning_rate"] * scale

                lr = g["lr"]

            # Gradient accumulation: step/zero every cfg["subdivisions"] batches.
            if batch_num % cfg["subdivisions"] == 0:
                optimizer.step()
                optimizer.zero_grad()

            info = "Epoch:%d LR:%f CIou:%f Obj:%f Cls:%f Total:%f" % (
                    epoch, lr, iou_loss, obj_loss, cls_loss, total_loss)
            pbar.set_description(info)

            batch_num += 1

        # Evaluate and save a checkpoint every 10 epochs (skipping epoch 0).
        if epoch % 10 == 0 and epoch > 0:
            net.eval()
            print("compute mAP...")
            _, _, AP, _ = utils.utils.evaluation(val_dataloader, cfg, net, device)
            print("compute PR...")
            precision, recall, _, f1 = utils.utils.evaluation(val_dataloader, cfg, net, device, 0.3)
            print("Precision:%f Recall:%f AP:%f F1:%f"%(precision, recall, AP, f1))

            torch.save(net.state_dict(), "weights/%s-%d-epoch-%fap-model.pth" %
                      (cfg["model_name"], epoch, AP))

        scheduler.step()


2.制作自己数据集

import cv2

# Capture frames from camera index 1.
# Keys: 'q' saves the current frame to ./data/<i>.jpg, Esc (27) quits.
cap = cv2.VideoCapture(1)
i = 1
while cap.isOpened():
    ret, f = cap.read()
    # Bug fix: the original never checked `ret`, so a failed read
    # (camera unplugged / busy) crashed cv2.imshow with a None frame.
    if not ret:
        break
    c = cv2.waitKey(1)
    cv2.imshow("f", f)
    if c == 27:
        break
    elif c == ord("q"):
        print(i)
        path = "./data/" + str(i) + ".jpg"
        cv2.imwrite(path, f)
        i += 1
    else:
        pass

cap.release()
cv2.destroyAllWindows()

3.采用labelimg打标签

pip install labelimg

4.数据处理

import xml.etree.ElementTree as et
import os

# Class names in label order; an object's index in this list becomes its YOLO class id.
classes=["a","b"]

def find_xml(path):
    """Collect every .xml file directly under *path*.

    Returns a list of "path/filename" strings, one per directory entry
    whose name ends with ".xml".
    """
    return [path + "/" + entry
            for entry in os.listdir(path)
            if entry.endswith(".xml")]

def xml2txt(path, x=640, y=480):
    """Convert Pascal-VOC XML annotations under *path* to YOLO txt files.

    For each foo.xml a foo.txt is written beside it, one line per object:
    "<class_id> <cx> <cy> <w> <h>", with coordinates normalized by the
    image size (x = width, y = height). Objects whose class name is not
    in `classes` are skipped.
    """
    xml_list = find_xml(path)
    for xml in xml_list:
        file = xml[:-4]
        txt = file + ".txt"
        # Parse the XML annotation into an element tree.
        tree = et.parse(xml)
        root = tree.getroot()
        with open(txt, 'w') as f:
            # Iterate over every <object> node in the annotation.
            for obj in root.iter('object'):
                cls = obj.find('name').text
                if cls not in classes:
                    continue
                cls_id = classes.index(cls)
                xmlbox = obj.find('bndbox')
                # Bug fix: parse with float() instead of int() — labelImg can
                # emit fractional pixel coordinates such as "123.4", which
                # crashed int(). Output is unchanged for integer inputs since
                # every written value was already a float.
                b = (float(xmlbox.find('xmin').text), float(xmlbox.find('ymin').text),
                     float(xmlbox.find('xmax').text), float(xmlbox.find('ymax').text))
                w = float(b[2] - b[0])
                h = float(b[3] - b[1])
                center_x = float(b[0] + w / 2)
                center_y = float(b[1] + h / 2)
                # Normalize to [0, 1] in YOLO order: cx, cy, w, h.
                d = [center_x / x, center_y / y, w / x, h / y]
                f.write(str(cls_id) + " " + " ".join([str(a) for a in d]) + "\n")


if __name__ == '__main__':
    # Convert all XML annotations in ./xml to YOLO-format txt files.
    xml2txt("./xml")

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值