YOLOv5 Model Training and Quantization Tutorial

YOLOv5 Model Training Tutorial

Data Annotation

For data annotation we use labelImg:

pip install labelimg
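
After installation, labelImg can be launched from the command line; the image directory and a predefined-class file are optional arguments (the paths below are examples). Keep the default PascalVOC export format so that an XML annotation file is written for each image:

labelImg VOCData/images VOCData/predefined_classes.txt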

Crawling images with a Baidu image crawler

import os
import re
import sys
import urllib
import json
import socket
import urllib.request
import urllib.parse
import urllib.error
# Set a global socket timeout
from random import randint
import time

timeout = 5
socket.setdefaulttimeout(timeout)


class Crawler:
    # Sleep interval between requests (seconds)
    __time_sleep = 0.1
    __amount = 0
    __start_amount = 0
    __counter = 0
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64; rv:23.0) Gecko/20100101 Firefox/23.0'}
    __per_page = 30

    # Fetch image URLs and download the images
    # t: interval between downloads (seconds)
    def __init__(self, t=0.1):
        self.time_sleep = t

    # Get the file extension
    @staticmethod
    def get_suffix(name):
        m = re.search(r'\.[^\.]*$', name)
        if m and len(m.group(0)) <= 5:
            return m.group(0)
        else:
            return '.jpeg'

    # Save the images in one result page
    def save_image(self, rsp_data, word):
        if not os.path.exists("./" + word):
            os.mkdir("./" + word)
        # Count existing files to avoid name collisions and continue the numbering
        self.__counter = len(os.listdir('./' + word)) + 1
        for image_info in rsp_data['data']:
            try:
                if 'replaceUrl' not in image_info or len(image_info['replaceUrl']) < 1:
                    continue
                obj_url = image_info['replaceUrl'][0]['ObjUrl']
                thumb_url = image_info['thumbURL']
                url = 'https://image.baidu.com/search/down?tn=download&ipn=dwnl&word=download&ie=utf8&fr=result&url=%s&thumburl=%s' % (
                    urllib.parse.quote(obj_url), urllib.parse.quote(thumb_url))
                time.sleep(self.time_sleep)
                suffix = self.get_suffix(obj_url)
                # Set a User-Agent to reduce 403 responses
                opener = urllib.request.build_opener()
                opener.addheaders = [
                    ('User-agent', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.116 Safari/537.36'),
                ]
                urllib.request.install_opener(opener)
                # Save the image
                filepath = './{}/PME_{}_A{}'.format(word, randint(
                    1000000, 500000000), str(self.__counter) + str(suffix))
                # Retry up to 5 times in case an (almost) empty file is returned
                for _ in range(5):
                    urllib.request.urlretrieve(url, filepath)
                    if os.path.getsize(filepath) >= 5:
                        break
                if os.path.getsize(filepath) < 5:
                    print("下载到了空文件,跳过!")
                    os.unlink(filepath)
                    continue
            except urllib.error.HTTPError as urllib_err:
                print(urllib_err)
                continue
            except Exception as err:
                time.sleep(1)
                print(err)
                print("产生未知错误,放弃保存")
                continue
            else:
                print("图+1,已有" + str(self.__counter) + "张图")
                self.__counter += 1
        return

    # Start fetching
    def get_images(self, word):
        search = urllib.parse.quote(word)
        # pn: offset into the search results
        pn = self.__start_amount
        while pn < self.__amount:

            url = 'https://image.baidu.com/search/acjson?tn=resultjson_com&ipn=rj&ct=201326592&is=&fp=result&queryWord=%s&cl=2&lm=-1&ie=utf-8&oe=utf-8&adpicid=&st=-1&z=&ic=&hd=&latest=&copyright=&word=%s&s=&se=&tab=&width=&height=&face=0&istype=2&qc=&nc=1&fr=&expermode=&force=&pn=%s&rn=%d&gsm=1e&1594447993172=' % (
                search, search, str(pn), self.__per_page)
            page = None
            try:
                time.sleep(self.time_sleep)
                # Set request headers to reduce 403 responses
                req = urllib.request.Request(url=url, headers=self.headers)
                page = urllib.request.urlopen(req)
                rsp = page.read()
                # Parse the JSON response and save the images
                rsp_data = json.loads(rsp)
                self.save_image(rsp_data, word)
                print("Fetching next page")
            except UnicodeDecodeError as e:
                print(e)
                print('-----UnicodeDecodeError url:', url)
            except urllib.error.URLError as e:
                print(e)
                print("-----URLError url:", url)
            except socket.timeout as e:
                print(e)
                print("-----socket timeout:", url)
            except Exception as e:
                print(e)
                print("Failed to parse the response, skipping this page")
            finally:
                if page:
                    page.close()
                # Always advance to the next page so errors cannot cause an infinite loop
                pn += self.__per_page
        print("Download finished")
        return

    def start(self, word, total_page=2, start_page=1, per_page=30):
        """
        Crawler entry point
        :param word: search keyword
        :param total_page: number of result pages to fetch; total images = total_page x per_page
        :param start_page: starting page number
        :param per_page: images per page
        :return:
        """
        self.__per_page = per_page
        self.__start_amount = (start_page - 1) * self.__per_page
        self.__amount = total_page * self.__per_page + self.__start_amount
        self.get_images(word)


if __name__ == '__main__':

    crawler = Crawler(0.05)  # download delay of 0.05 s

    crawler.start('玩手机')  # search keyword (Chinese for "playing with a phone")
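
Save the script (the file name is arbitrary, e.g. crawler.py) and run it; the images are downloaded into a folder named after the search keyword:

python crawler.py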


After annotation is finished, each image has a corresponding XML annotation file.
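
labelImg writes the annotations in PASCAL VOC format; a file looks roughly like this (values are illustrative) and contains exactly the fields the conversion script below reads:

<annotation>
    <filename>phone_001.jpg</filename>
    <size>
        <width>640</width>
        <height>480</height>
        <depth>3</depth>
    </size>
    <object>
        <name>phone</name>
        <difficult>0</difficult>
        <bndbox>
            <xmin>120</xmin>
            <ymin>80</ymin>
            <xmax>360</xmax>
            <ymax>300</ymax>
        </bndbox>
    </object>
</annotation>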

Data Preprocessing

Create a convert_data.py file with the following content:

# -*- coding: utf-8 -*-

import xml.etree.ElementTree as ET
from tqdm import tqdm
import os
from os import getcwd


# Convert a VOC box (xmin, xmax, ymin, ymax) on an image of the given size
# into normalized YOLO (x_center, y_center, width, height)
def convert(size, box):
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return x, y, w, h


# Convert one VOC XML annotation into a YOLO-format label file
def convert_annotation(image_id):
    in_file = open('VOCData/images/{}.xml'.format(image_id), encoding='utf-8')
    out_file = open('VOCData/labels/{}.txt'.format(image_id),
                    'w', encoding='utf-8')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
             float(xmlbox.find('ymax').text))
        b1, b2, b3, b4 = b
        # Clip boxes that extend beyond the image boundary
        if b2 > w:
            b2 = w
        if b4 > h:
            b4 = h
        b = (b1, b2, b3, b4)
        bb = convert((w, h), b)
        out_file.write(str(cls_id) + " " +
                       " ".join([str(a) for a in bb]) + '\n')


if __name__ == '__main__':

    sets = ['train', 'val']

    image_ids = [v.split('.')[0]
                 for v in os.listdir('VOCData/images/') if v.endswith('.xml')]

    split_num = int(0.95 * len(image_ids))

    classes = ['face', 'normal', 'phone', 'write',
               'smoke', 'eat', 'computer', 'sleep']

    if not os.path.exists('VOCData/labels/'):
        os.makedirs('VOCData/labels/')

    list_file = open('train.txt', 'w')
    for image_id in tqdm(image_ids[:split_num]):
        list_file.write('VOCData/images/{}.jpg\n'.format(image_id))
        convert_annotation(image_id)
    list_file.close()

    list_file = open('val.txt', 'w')
    for image_id in tqdm(image_ids[split_num:]):
        list_file.write('VOCData/images/{}.jpg\n'.format(image_id))
        convert_annotation(image_id)
    list_file.close()


After the script finishes, the corresponding .txt label files appear under VOCData/labels, and train.txt / val.txt are written to the project root.
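
Each line of a label file follows the YOLO format class_id x_center y_center width height, with all coordinates normalized to [0, 1], for example (values are illustrative):

2 0.453125 0.617188 0.212500 0.308333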

Create a myvoc.yaml file in the data folder with the following content:

train: train.txt
val: val.txt

# number of classes
nc: 8

# class names
names: ["face", "normal", "phone", "write", "smoke", "eat", "computer", "sleep"]
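
YOLOv5 reads the image paths from train.txt / val.txt and looks up each label by replacing the images directory in the path with labels, so the expected layout is roughly:

VOCData/
├── images/   # .jpg files (plus the original .xml annotations)
└── labels/   # generated .txt label files, one per image
train.txt
val.txt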


Download the Pretrained Model

I train the yolov5m model, so its pretrained weights need to be downloaded into the weights folder.
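
The weights are available from the Ultralytics YOLOv5 GitHub releases page, for example (the release tag should match the version of the repository you cloned):

wget https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5m.pt -P weights/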

Model Training

Modify the number of classes in models/yolov5m.yaml:
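
In the model yaml only the nc field needs to change to match our dataset (it defaults to 80 for COCO):

nc: 8  # number of classes

Then start training: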

python train.py --img 640 --batch 4 --epochs 300 --data ./data/myvoc.yaml --cfg ./models/yolov5m.yaml --weights weights/yolov5m.pt --workers 0

Model Inference Test

After training finishes, two weight files (last.pt and best.pt) are generated under runs/train/exp/weights. Copy last.pt to the repository root and run:

python detect.py --source data/images --weights last.pt --conf 0.25

Model Quantization

At this point we notice that the trained last.pt is about 172 MB, while the official yolov5m.pt is only about 40 MB. The training checkpoint also carries optimizer state and stores the weights in FP32, so we export a half-precision (FP16) copy of the model and save it again. Create a slim.py file with the content below, then run:

python slim.py --in_weights last.pt --out_weights slim_model.pt --device 0

slim.py:
import os

import torch
import torch.nn as nn
from tqdm import tqdm


def autopad(k, p=None):  
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    # Standard convolution
    # ch_in, ch_out, kernel, stride, padding, groups
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p),
                              groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.Hardswish() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))


class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        super(Ensemble, self).__init__()

    def forward(self, x, augment=False):
        y = []
        for module in self:
            y.append(module(x, augment)[0])
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.cat(y, 1)  # nms ensemble
        y = torch.stack(y).mean(0)  # mean ensemble
        return y, None  # inference, train output


def attempt_load(weights, map_location=None):

    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        # load FP32 model
        model.append(torch.load(
            w, map_location=map_location)['model'].float().fuse().eval())

    # Compatibility updates
    for m in tqdm(model.modules()):
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
            m.inplace = True  # pytorch 1.7.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model
    else:
        print('Ensemble created with %s\n' % weights)
        for k in ['names', 'stride']:
            setattr(model, k, getattr(model[-1], k))
        return model  # return ensemble


def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    cpu_request = device.lower() == 'cpu'
    if device and not cpu_request:  # a specific GPU was requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), \
            'CUDA unavailable, invalid device %s requested' % device  # check availability

    cuda = False if cpu_request else torch.cuda.is_available()
    if cuda and batch_size:  # check that batch_size is compatible with the GPU count
        ng = torch.cuda.device_count()
        if ng > 1:
            assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (
                batch_size, ng)

    return torch.device('cuda:0' if cuda else 'cpu')


if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--in_weights', type=str,
                        default='last.pt', help='initial weights path')
    parser.add_argument('--out_weights', type=str,
                        default='slim_model.pt', help='output weights path')
    parser.add_argument('--device', type=str, default='0', help='device')
    opt = parser.parse_args()

    device = select_device(opt.device)
    model = attempt_load(opt.in_weights, map_location=device)
    model.to(device).eval()
    model.half()

    torch.save(model, opt.out_weights)
    print('done.')

    print('-[INFO] before: {:.1f} MB, after: {:.1f} MB'.format(
        os.path.getsize(opt.in_weights) / 1024 / 1024,
        os.path.getsize(opt.out_weights) / 1024 / 1024))
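
As a quick sanity check you can load the slimmed model and run a dummy forward pass. This is a minimal sketch, assuming a CUDA device is available and that slim_model.pt was produced by the script above; run it from the YOLOv5 repository root so the pickled model classes can be resolved:

import torch

model = torch.load('slim_model.pt', map_location='cuda:0')  # full FP16 model object saved by slim.py
model.eval()

img = torch.zeros(1, 3, 640, 640, device='cuda:0').half()  # dummy half-precision input
with torch.no_grad():
    pred = model(img)[0]  # raw predictions before NMS
print(pred.shape, pred.dtype)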