用YOLOV3,从0开始训练自己的数据集+测试

假设你已经把数据准备好了,也安装好了darknet,如果还未安装好,请看我的另一篇博客:https://blog.csdn.net/qq_32473523/article/details/107252345

假设前面一切准备妥当,那么我们将从头开始训练自己的数据集。

注意所有的txt文件,不要有多余的换行,不然读数据的时候可能问题(txt文件不自己改就不会有问题)

part1.数据部分

1.先将准备好的数据放入darknet中

在darknet文件夹下新建一个存储数据的文件夹,然后数据按照Pascal VOC DATA的格式存放,(我新建的文件夹名字叫datapig)如下图所示:

说明:新建好datapig文件夹后,在该文件夹下只需要建好Annotations(用来存放标注好的图片的XML文件的)、ImageSets、JPEGImages(用来存放标注的图片)三个文件夹,其余的文件夹忽视掉,ImageSets文件夹下再建一个Main文件夹,用来存放trainval.txt和test.txt。

2.生成上一步所需要的trainval.txt和test.txt文件

import os
import random

trainval_percent = 0.8   # trainval占总数的比例
# train_percent = 0.5   # train占trainval的比例
xmlfilepath = "/home/ubantu/darknet/datapig/Annotations/"
txtsavepath = "/home/ubantu/darknet/datapig/ImageSets/Main/"
total_xml = os.listdir(xmlfilepath)    # 列举当前目录下所有的文件,返回的是列表类型

num = len(total_xml)                   # 获得总的文件个数 这里就是xml文件的个数
l = range(num)                        # 生成一个整数列表 如num=10,则为[1,2,3,4,...9]
tv = int(num * trainval_percent)
# tr = int(tv * train_percent)       #我直接将trainval作为训练集 故这里注释掉 你们自己的就看情况
trainval = random.sample(l, tv)    # 多个字符中生成指定数量的随机字符
# train = random.sample(trainval, tr) # 同理 注释掉
print(len(set(trainval)))
ftrainval = open(txtsavepath + 'trainval.txt', 'w')
ftest = open(txtsavepath + 'test.txt', 'w')
ftest = open(txtsavepath + 'a.txt', 'w')
# ftrain = open(txtsavepath + r'\train.txt', 'w')
# fval = open(txtsavepath + r'\val.txt', 'w')
count1,count2 = 0,0
for i in l:
    name = total_xml[i][:-4] + '\n'
    if i in trainval:
        ftrainval.write(name)
        count1+=1
        # 同理 注释掉
        # if i in train: 
        #     ftrain.write(name)
        # else:
        #     fval.write(name)
    else:
        ftest.write(name)
        count2+=1
print(count1,count2)

生成的txt文件只有图片名的前缀,如下:

3.为图片数据集生成txt的标签文件 

下面是我自己修改的代码,源代码在darknet/scripts/目录下,由voc_lable.py修改得到。

import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join

#sets=[('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test')]

#classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", #"diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", #"tvmonitor"]
classes = ["pig"]

def convert(size, box):
    dw = 1./size[0]
    dh = 1./size[1]
    x = (box[0] + box[1])/2.0
    y = (box[2] + box[3])/2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x*dw
    w = w*dw
    y = y*dh
    h = h*dh
    return (x,y,w,h)

def convert_annotation(name):
    in_file = open('Annotations/%s.xml'%(name))
    out_file = open('labels/%s.txt'%(name), 'w')
    tree=ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    #print(w,h)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        #print(cls,type(cls))
        #cls = int(cls)
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        #print(name)
        bb = convert((w,h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

wd = getcwd()

"""ImageSets/Main文件夹下的trainval.txt含有的只是相应图片的名字,而在datapig文件下的trainval.txt是图片的地址"""
if not os.path.exists('labels/'):
        os.makedirs('labels/')
image_ids = open('ImageSets/Main/trainval.txt').read().strip().split()
#print(image_ids)
list_file = open('trainval.txt', 'w')
for image_id in image_ids:
    list_file.write('%s/JPEGImages/%s.jpg\n'%(wd,image_id))
    convert_annotation(image_id)
list_file.close()
#生成test.txt文件
image_idst = open('ImageSets/Main/test.txt').read().strip().split()
list_file = open('test.txt', 'w')
for image_id in image_idst:
    list_file.write('%s/JPEGImages/%s.jpg\n'%(wd,image_id))
    convert_annotation(image_id)
list_file.close()

运行结束后在刚才新建的数据文件夹(我的是datapig文件夹)下生成一个trainval.txt和一个test.txt文件,注意其与ImageSets/Main文件夹下的trainval.txt的区别。 

part2.配置部分:

1.打开data文件,复制一份voc.names,把其中的name改成自己的

2..打开cfg文件,复制一个voc.data,自定义文件名,运行的时候会用这个文件,如下,我把名称改成pig.data。classes就是类别的数目,train就是上一步生成的trainval.txt的路径,————————————

names的值就是上一步的pig.names所在的路径。backup就是存储训练时的权重的,我自己在darknet文件夹下新建了一个backpig文件夹,用来存储权重。

 3.打开cfg文件夹,复制一份yolov3-voc.cfg,改成自己想要的名字,然后依次修改如下内容:

第一处training这里,记得把test的注释掉,test的时候记得把train注释掉,不然会有很多问题。

[net]
# Testing
#batch=1
#subdivisions=1
# Training
batch=64
subdivisions=16
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1

第二处,该配置文件中有三处[yolo] ,每一处[yolo]前的filters改成3*(5+classes数目),[yolo]中的classes改成类别数。

[convolutional]
size=1
stride=1
pad=1
#filters=75
filters=18#改这里
activation=linear

[yolo]
mask = 6,7,8
anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
classes=1#改这里
num=9
jitter=.3
ignore_thresh = .5
truth_thresh = 1
random=1

part3:开始训练:

如果之前没下载初始权重,那么可以通过下面代码在darknet目录下下载权重

wget https://pjreddie.com/media/files/darknet53.conv.74

通过下面语句进行训练,注意要切换darknet目录下执行。

 ./darknet detector train cfg/pig.data cfg/yolov3-pig.cfg darknet53.conv.74 -gpus 0,1,2,3

Part4:单张测试: 

进行测试前记得修改cfg文件夹下的yolov3-pig.cfg,改为测试模式,改完部分如下

[net]
# Testing 测试阶段一定要调成这样,不然预测不出来
batch=1
subdivisions=1
# Training
#batch=64
#subdivisions=16
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
./darknet detector test  cfg/pig.data cfg/yolov3-pig.cfg yolov3-pig_final.weights data/pig1.jpg

Part5:多张一起测试,测试mAP 

1.在terminal中执行下列语句,在darknet/results文件夹下生成相应的txt文件

 ./darknet detector valid cfg/pig.data cfg/yolov3-pig.cfg backuppig/yolov3-pig_final.weights -out "" -gpu 0 -thresh .5

2.计算mAP

darknet目录下新建一个voc_eval.py的文件,内容如下:

import xml.etree.ElementTree as ET
import os
#import cPickle
import _pickle as cPickle
import numpy as np

def parse_rec(filename):
    """ Parse a PASCAL VOC xml file """
    tree = ET.parse(filename)
    objects = []
    for obj in tree.findall('object'):
        obj_struct = {}
        obj_struct['name'] = obj.find('name').text
        obj_struct['pose'] = obj.find('pose').text
        obj_struct['truncated'] = int(obj.find('truncated').text)
        obj_struct['difficult'] = int(obj.find('difficult').text)
        bbox = obj.find('bndbox')
        obj_struct['bbox'] = [int(bbox.find('xmin').text),
                              int(bbox.find('ymin').text),
                              int(bbox.find('xmax').text),
                              int(bbox.find('ymax').text)]
        objects.append(obj_struct)

    return objects

def voc_ap(rec, prec, use_07_metric=False):
    """ ap = voc_ap(rec, prec, [use_07_metric])
    Compute VOC AP given precision and recall.
    If use_07_metric is true, uses the
    VOC 07 11 point method (default:False).
    """
    if use_07_metric:
        # 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             cachedir,
             ovthresh=0.5,
             use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])
    Top level function that does the PASCAL VOC evaluation.
    detpath: Path to detections
        detpath.format(classname) should produce the detection results file.
    annopath: Path to annotations
        annopath.format(imagename) should be the xml annotations file.
    imagesetfile: Text file containing the list of images, one image per line.
    classname: Category name (duh)
    cachedir: Directory for caching the annotations
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)
    """
    # assumes detections are in detpath.format(classname)
    # assumes annotations are in annopath.format(imagename)
    # assumes imagesetfile is a text file with each line an image name
    # cachedir caches the annotations in a pickle file

    # first load gt
    if not os.path.isdir(cachedir):
        os.mkdir(cachedir)
    cachefile = os.path.join(cachedir, 'annots.pkl')
    
    # read list of images
    with open(imagesetfile, 'r') as f:
        lines = f.readlines()
    imagenames = [x.strip() for x in lines] #文件名
    
    if not os.path.isfile(cachefile):
        #print("zaybnzazazazazazazaza")
        # load annots
        recs = {}
        for i, imagename in enumerate(imagenames):
            recs[imagename] = parse_rec(annopath.format(imagename))
            if i % 100 == 0:
                print('Reading annotation for {:d}/{:d}'.format(
                    i + 1, len(imagenames)))
        # save
        print('Saving cached annotations to {:s}'.format(cachefile))
        with open(cachefile, 'wb') as f:
            cPickle.dump(recs, f)
    else:
        # load
        with open(cachefile, 'rb') as f:
            try:
                recs = cPickle.load(f)
            except EOFError:
                return 
            #recs = cPickle.load(f)
    # extract gt objects for this class
    class_recs = {}
    npos = 1
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        difficult = np.array([x['difficult'] for x in R]).astype(np.bool)
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': det}
    # read dets
    detfile = detpath.format(classname)
    with open(detfile, 'rb+') as f:
        lines = f.readlines()
    #print("type(lines[0]):",type(lines[0]))
    #print("type(x):",type(str(lines[0]).strip().split(" ")))
    splitlines = [str(x).strip().split(' ') for x in lines]
    
    #splitlines = splitlines.encode()
    #print(type(splitlines))
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    a = "\\n'"
#     for x in splitlines:
#         for z in x[2:]:
#             if a in z:
#                 print(z[:len(z)-3])
#             else:
#                 print(z)
            
            
    #remove \n
    BB = np.array([[float(z) if a not in z else float(z[:len(z)-3]) for z in x[2:]] for x in splitlines])
    #print(BB)

    # sort by confidence
    sorted_ind = np.argsort(-confidence)
    sorted_scores = np.sort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        #print(image_ids[d][2:])
        #print(class_recs)
        R = class_recs[image_ids[d][2:]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)

        if BBGT.size > 0:
            # compute overlaps
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            inters = iw * ih

            # union
            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                   (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                   (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        if ovmax > ovthresh:
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = 1
                else:
                    fp[d] = 1.
            fp[d] = 1.

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap

再新建一个compute_mAP.py文件,内容如下:

from voc_eval import voc_eval

import os

current_path = os.getcwd()
results_path = current_path+"/results"
sub_files = os.listdir(results_path)

mAP = []
for i in range(len(sub_files)):
    class_name = sub_files[i].split(".txt")[0]
    rec, prec, ap = voc_eval('/home/ubantu/darknet/results/{}.txt', '/home/ubantu/darknet/datapig/Annotations/{}.xml', '/home/ubantu/darknet/datapig/ImageSets/Main/test.txt', class_name, '.')
    print("{} :\t {} ".format(class_name, ap))
    mAP.append(ap)

mAP = tuple(mAP)

print("***************************")
print("len(mAP):",len(mAP))
print("mAP :\t {}".format( float( sum(mAP)/len(mAP)) )) 

建好以后,在terminal中执行

python compute_mAP.py

就可以了。

这是我的结果:

大功告成了!
 

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值