[Deep Learning] Object Detection Based on YOLO

Contents

Data Preparation

Data Collection

Data Cleaning

Data Organization

Data Annotation

Converting VOC Format to txt

Data Splitting

Code Preparation

Training the Network

Dataset Configuration

Training the Network on the Training Set

Inference Testing


This post uses football (soccer ball) detection as its running example.

Data Preparation

Data Collection

First we need some football photos. I use the script below to fetch them from flickr. Note that if you want to use this script yourself, you should log in to flickr.com, run a search, and copy the request headers and request parameters from your browser to replace the ones below. The number of images the script retrieves depends on how many results the site has for the query and on the parameters you pass.


import asyncio
import pathlib
import time
from concurrent.futures import ThreadPoolExecutor

import aiofiles
import httpx
from tqdm import tqdm

headers = {
    "Host": "api.flickr.com",
    "Connection": "keep-alive",
    "sec-ch-ua": "\\"Chromium\\";v=\\"110\\", \\"Not A(Brand\\";v=\\"24\\", \\"Microsoft Edge\\";v=\\"110\\"",
    "DNT": "1",
    "sec-ch-ua-mobile": "?0",
    "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.63",
    "sec-ch-ua-platform": "\\"macOS\\"",
    "Accept": "*/*",
    "Origin": "<https://www.flickr.com>",
    "Sec-Fetch-Site": "same-site",
    "Sec-Fetch-Mode": "cors",
    "Sec-Fetch-Dest": "empty",
    "Referer": "<https://www.flickr.com/>",
    "Accept-Encoding": "gzip, deflate, br",
    "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6",
    "Cookie": "xb=716910; sp=edba229b-a680-4e45-8953-b22e0b1d807a; __ssid=7e179ee83966a990a2475f22965ac5e; localization=zh-hk%3Bus%3Bus; ccc=%7B%22needsConsent%22%3Afalse%2C%22managed%22%3A0%2C%22changed%22%3A0%2C%22info%22%3A%7B%22cookieBlock%22%3A%7B%22level%22%3A0%2C%22blockRan%22%3A0%7D%7D%2C%22freshServerContext%22%3Atrue%7D; _sp_ses.df80=*; flrbp=1679324497-e7a4016caa17074fabbdaaa6571a040272a5d75e; flrbgrp=1679324497-c30919e68dab365435c113ce0411ea5b59c374b6; flrbgdrp=1679324497-6eec5e06b29b3b6f6c5c3882de81131abbe4f6ce; flrbgmrp=1679324497-413d84b71c022fd6b8238eb06e5609c64ce80f82; flrbrst=1679324497-896de7174620cfbe221989eebeb461beb5271889; flrtags=1679324497-0150a7c5b3f0fcf2fe7e96654d09cfd67704e1be; flrbrp=1679324497-76a155764becf1bcc4a67ba5fac96f4df1a603bd; flrb=36; vp=1532%2C946%2C2%2C0%2Ctag-photos-everyone-view%3A1226%2Csearch-photos-prints-view%3A886%2Csearch-photos-everyone-view%3A886; _sp_id.df80=20784097-d58b-4999-b8e5-734d21dab34b.1666181791.8.1679325014.1667110194.1dc51818-f960-4b25-ae5b-ad80c1d0dc66.87f8ed05-0901-495e-aa14-bdf2723d6190.e5e25774-f0d1-42d7-9fd4-e936d4a0445b.1679324485578.26",
}

def fetch_src_json(keyword, page=1, times=5) -> dict:
    # print(f"keyword {keyword}, page {page}, attempt {6 - times}")
    time.sleep(0.02)
    if times == 0:
        return {}
    url = "https://api.flickr.com/services/rest"
    params = {
        "sort": "relevance",
        "parse_tags": "1",
        "content_types": "0,1,2,3",
        "video_content_types": "0,1,2,3",
        "extras": "can_comment,can_print,count_comments,count_faves,description,isfavorite,license,media,needs_interstitial,owner_name,path_alias,realname,rotation,url_sq,url_q,url_t,url_s,url_n,url_w,url_m,url_z,url_c,url_l",
        "per_page": "50",
        "page": "2",
        "lang": "zh-HK",
        "text": "足球",
        "viewerNSID": "",
        "method": "flickr.photos.search",
        "csrf": "",
        "api_key": "5431b12ac9f1f58ad25ec5209ae3197a",
        "format": "json",
        "hermes": "1",
        "hermesClient": "1",
        "reqId": "d986c176-716c-463b-ae40-ea05e1575e64",
        "nojsoncallback": "1",
    }
    params.update(dict(text=f"{keyword}", page=f"{page}", per_page="75"))
    try:
        proxies = {
            "http://": "http://127.0.0.1:1087",
            "https://": "http://127.0.0.1:1087",
        }
        resp = httpx.get(url=url, headers=headers, params=params, verify=False, proxies=proxies)
        if resp.status_code == 200:
            result = resp.json()
            if not result:
                return fetch_src_json(keyword, page, times=times - 1)
            return result
        return fetch_src_json(keyword, page, times=times - 1)
    except Exception as e:
        return fetch_src_json(keyword, page, times=times - 1)

def get_image_links(resp):
    def get_link(d: dict):
        image_qualities = ["l", "c", "z", "m", "w", "n", "s", "t", "q", "sq"]
        for quality_flag in image_qualities:
            quality = f"url_{quality_flag}_cdn"
            if quality in d:
                return d.get(quality)
        return ""

    if resp:
        return [url for url in [get_link(photo) for photo in resp.get("photos", {}).get("photo", [])]
                if url]  # an empty string or None means no usable link
    return []

def get_page_size(resp: dict) -> int:
    return resp.get("photos", {}).get("pages", 0)

async def save_media(url, folder, times=3):
    if times == 0:
        return
    filename = url.split("/")[-1]
    save_folder = pathlib.Path("dataset", "image", folder)
    async with httpx.AsyncClient(verify=False) as client:
        try:
            resp = await client.get(url=url)
        except Exception as e:
            return await save_media(url, folder, times=times - 1)
        if resp is None:
            return await save_media(url, folder, times=times - 1)
        save_folder.mkdir(parents=True, exist_ok=True)
        async with aiofiles.open(f"{str(save_folder)}/{filename}", mode="wb") as f:
            await f.write(resp.content)

def fetch_all_media(urls, folder):
    async def run_all():
        # asyncio.wait() no longer accepts bare coroutines, so gather them instead
        await asyncio.gather(*(save_media(url, folder) for url in urls))
    if urls:
        asyncio.run(run_all())

def fetch_by_keyword(keyword, qty):
    resp = fetch_src_json(keyword)
    if not resp:
        print(f"Failed to fetch image results for keyword {keyword}")
        return
    pages = get_page_size(resp)
    count = 0
    for page in tqdm(range(1, pages)):
        if page != 1:
            resp = fetch_src_json(keyword, page)
        image_links = get_image_links(resp)
        if not image_links:
            resp = fetch_src_json(keyword, page)
            image_links = get_image_links(resp)
        print(f"{keyword}已有{count}张,新保存{len(image_links)}张")
        if not image_links:
            continue
        count += len(image_links)
        fetch_all_media(image_links, keyword)
        if count >= qty:
            break

if __name__ == '__main__':
    keywords = ["足球"]
    threadPool = ThreadPoolExecutor(8)
    for keyword in keywords:
        threadPool.submit(fetch_by_keyword, keyword, 20000)
        # fetch_by_keyword(keyword, 20000)
    threadPool.shutdown(wait=True)

Data Cleaning

Not every collected image necessarily contains the target we want to detect, so the images need to be filtered to make sure every image used for training actually contains the target.
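The filtering itself is a manual pass; below is a minimal review sketch of one way to do it (my own sketch, assuming the images were saved by the download script above): each image is shown with OpenCV, and pressing "d" deletes it while any other key keeps it.


# Minimal manual-review loop: press "d" to delete an image, any other key to keep it.
import pathlib

import cv2

image_dir = pathlib.Path("dataset", "image", "足球")  # adjust to your download folder
for image_path in sorted(image_dir.glob("*.jpg")):
    img = cv2.imread(str(image_path))
    if img is None:  # unreadable or corrupt file: remove it
        image_path.unlink()
        continue
    cv2.imshow("review", img)
    key = cv2.waitKey(0) & 0xFF
    if key == ord("d"):  # no football in this image
        image_path.unlink()
    elif key == 27:  # Esc stops the review early
        break
cv2.destroyAllWindows()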

Data Organization

After cleaning, we move the images into a single archive folder, rename them, and unify the image format. The following script handles this.


from pathlib import Path


def move_img_to_jpg(base_path: str, keyword: str):
    path_aim = Path(base_path, keyword)
    if not path_aim.exists():
        path_aim.mkdir()
    # adjust these source directories as needed
    # aim_dirs = Path(base_path).glob(f"{keyword}_*")
    aim_dirs = [Path(base_path, "足球")]
    img_idx = 0
    for aim_dir in aim_dirs:
        image_paths: list[Path] = [*aim_dir.glob("*.jpg"), *aim_dir.glob("*.png"), *aim_dir.glob("*.jpeg")]
        for image in image_paths:
            img_idx += 1
            re_path = f"{keyword}_{img_idx}"
            print(f"rename {image.name} to {re_path}{image.suffix}")
            image.rename(Path(path_aim, f"{re_path}{image.suffix}"))
        aim_dir.rmdir()


if __name__ == '__main__':
    base_path = "path_to_your_download_image"
    move_img_to_jpg(base_path, "football")

Data Annotation

Data annotation means marking, in the organized images, the targets we want to detect. The annotation tool used here is labelImg (installed via pip install labelImg). Before opening it for the first time, point its open directory at the image folder and its save directory at the folder where the VOC files will be stored.
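For reference, each annotated image produces a VOC XML file shaped roughly like the sketch below (the file name and coordinates are illustrative); the conversion script in the next step reads exactly these fields:


<annotation>
    <filename>football_1.jpg</filename>
    <size>
        <width>1024</width>
        <height>768</height>
        <depth>3</depth>
    </size>
    <object>
        <name>football</name>
        <difficult>0</difficult>
        <bndbox>
            <xmin>321</xmin>
            <ymin>145</ymin>
            <xmax>410</xmax>
            <ymax>230</ymax>
        </bndbox>
    </object>
</annotation>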

Converting VOC Format to txt

Since labelImg saves annotations in VOC format, i.e., as XML, the XML needs to be converted into the txt format that YOLO reads. Use the script below for the conversion; to use it as-is, adjust the paths to match your setup.


import xml.etree.ElementTree as ET
from pathlib import Path


def cvt_voc2yolo(base_path: str, class_dict: dict):
    path_xml = Path(base_path, "xml")
    path_label = Path(base_path, "labels")
    if not path_label.exists():
        path_label.mkdir()
    file_xml = path_xml.glob("*.xml")
    for file in file_xml:
        _cvt_xml2yolo(file, class_dict)

def _cvt_xml2yolo(path_xml: Path, classes: dict):
    path_label = Path(path_xml.parent.parent, "labels", f"{path_xml.stem}.txt")
    if not path_label.exists():
        path_label.touch()
    with open(str(path_label.resolve()), 'w') as label_file:
        with open(str(path_xml.resolve()), "r", encoding='UTF-8') as xml_file:
            tree = ET.parse(xml_file)
            root = tree.getroot()
            size = root.find('size')
            size_width = int(size.find('width').text)
            size_height = int(size.find('height').text)
            for obj in root.iter('object'):
                difficult = obj.find('difficult').text
                cls = obj.find('name').text
                if cls not in classes or int(difficult) == 1:
                    continue
                cls_id = classes[cls]
                xmlbox = obj.find('bndbox')
                b = [float(xmlbox.find('xmin').text),
                     float(xmlbox.find('xmax').text),
                     float(xmlbox.find('ymin').text),
                     float(xmlbox.find('ymax').text)]

                if size_width == 0 or size_height == 0 or b[0] == b[1] or b[2] == b[3]:
                    print("invalid annotation, no labels written for", path_xml.stem)
                    label_file.close()
                    path_label.unlink()
                    return 1
                # clip out-of-bounds box coordinates
                b[0] = max(0, b[0])
                b[1] = min(size_width, b[1])
                b[2] = max(0, b[2])
                b[3] = min(size_height, b[3])
                txt_data = [round(((b[0] + b[1]) / 2.0 - 1) / size_width, 6),
                            round(((b[2] + b[3]) / 2.0 - 1) / size_height, 6),
                            round((b[1] - b[0]) / size_width, 6),
                            round((b[3] - b[2]) / size_height, 6)]
                if txt_data[0] < 0 or txt_data[1] < 0 or txt_data[2] < 0 or txt_data[3] < 0:
                    print("invalid annotation, no labels written for", path_xml.stem)
                    label_file.close()
                    path_label.unlink()
                    return 1
                label_file.write(str(cls_id) + " " + " ".join([str(a) for a in txt_data]) + '\n')
        return 0

if __name__ == '__main__':
    base_path = "dataset_path"
    cvt_voc2yolo(base_path, {'football': 0})
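After conversion, each txt label file contains one line per object in the form class_id x_center y_center width height, with all four coordinates normalized by the image size. For the XML example above, the script would emit:


0 0.355957 0.242839 0.086914 0.110677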

Data Splitting

Data splitting means reserving one part of the data for training, one part for validation, and one part for testing; adjust the split to your actual needs (a minimal split sketch follows the tree below). The directory structure after splitting looks like this (using the train directory as an example):


dataset/train
├── images
│   ├── football_1.jpg
│   ├── football_***.jpg
│   └── football_99.jpg
├── labels
│   ├── football_1.txt
│   ├── football_***.txt
│   └── football_99.txt
└── xml
    ├── football_1.xml
    ├── football_***.xml
    └── football_99.xml

This completes the data preparation.

Code Preparation

Download the yolov5 project. After a first download, it is recommended to test your environment with the coco128 dataset, as shown after the clone command below.


git clone https://github.com/ultralytics/yolov5.git
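To check the environment, install the dependencies and run a short training job on coco128; both the dataset and the yolov5s.pt weights are downloaded automatically on first use:


cd yolov5
pip install -r requirements.txt
python train.py --img 640 --batch 16 --epochs 3 --data coco128.yaml --weights yolov5s.pt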

Training the Network

Dataset Configuration


# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: /opt/***/dataset  # dataset root dir
train: train/images  # train images (relative to 'path')
val: train/images  # the validation set here is identical to the training set; adjust to your actual split
test:  # test images (optional)

# Classes
names:
  0: football
# there is only a single class here, hence this format

Training the Network on the Training Set


python train.py --data path_to_data_config/football.yaml

After training completes, the weight files are saved under runs/train/exp*/weights/.
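The command above relies on defaults for everything except the dataset config. A more explicit invocation, with illustrative values for the commonly tuned flags, would look like:


python train.py --data path_to_data_config/football.yaml --weights yolov5s.pt --img 640 --epochs 100 --batch 16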

Inference Testing


python detect.py --source pending_predict_path --weights best.pt --conf 0.4 --img 640

The inference results are saved under runs/detect/exp*.
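Besides detect.py, the trained weights can also be consumed programmatically through the repository's torch.hub interface; a minimal sketch (the weight and image paths are illustrative):


import torch

# load the custom-trained weights through the yolov5 hub entry point
model = torch.hub.load("ultralytics/yolov5", "custom", path="runs/train/exp/weights/best.pt")
model.conf = 0.4  # confidence threshold, matching the detect.py example above
results = model("path/to/test_image.jpg")
results.print()  # per-image detection summary
results.save()   # annotated images are written to runs/detect/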
