目录
本例以检测足球为例
数据准备
数据采集
首先我们需要获取一些足球的照片,我这里使用如下脚本在flickr上获取,需要注意的是,如果你想使用该脚本,请登录到flickr.com,点击搜索后复制请求头和请求参数进行替换。该脚本获取的图片数量取决于网站上已有的数量和给定的参数。
import asyncio
import pathlib
import time
from concurrent.futures import ThreadPoolExecutor
import aiofiles
import httpx
from tqdm import tqdm
# HTTP request headers captured from a logged-in flickr.com browser session.
# NOTE(review): the Cookie value is session-specific and will expire — replace
# it (and the api_key in the request params) with values copied from your own
# browser, as the tutorial text above instructs.
# BUG FIX: the Origin/Referer URLs had been garbled into "<https://...>" by
# markdown autolinking; the angle brackets would break same-site checks.
headers = {
    "Host": "api.flickr.com",
    "Connection": "keep-alive",
    "sec-ch-ua": '"Chromium";v="110", "Not A(Brand";v="24", "Microsoft Edge";v="110"',
    "DNT": "1",
    "sec-ch-ua-mobile": "?0",
    "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.63",
    "sec-ch-ua-platform": '"macOS"',
    "Accept": "*/*",
    "Origin": "https://www.flickr.com",
    "Sec-Fetch-Site": "same-site",
    "Sec-Fetch-Mode": "cors",
    "Sec-Fetch-Dest": "empty",
    "Referer": "https://www.flickr.com/",
    "Accept-Encoding": "gzip, deflate, br",
    "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6",
    "Cookie": "xb=716910; sp=edba229b-a680-4e45-8953-b22e0b1d807a; __ssid=7e179ee83966a990a2475f22965ac5e; localization=zh-hk%3Bus%3Bus; ccc=%7B%22needsConsent%22%3Afalse%2C%22managed%22%3A0%2C%22changed%22%3A0%2C%22info%22%3A%7B%22cookieBlock%22%3A%7B%22level%22%3A0%2C%22blockRan%22%3A0%7D%7D%2C%22freshServerContext%22%3Atrue%7D; _sp_ses.df80=*; flrbp=1679324497-e7a4016caa17074fabbdaaa6571a040272a5d75e; flrbgrp=1679324497-c30919e68dab365435c113ce0411ea5b59c374b6; flrbgdrp=1679324497-6eec5e06b29b3b6f6c5c3882de81131abbe4f6ce; flrbgmrp=1679324497-413d84b71c022fd6b8238eb06e5609c64ce80f82; flrbrst=1679324497-896de7174620cfbe221989eebeb461beb5271889; flrtags=1679324497-0150a7c5b3f0fcf2fe7e96654d09cfd67704e1be; flrbrp=1679324497-76a155764becf1bcc4a67ba5fac96f4df1a603bd; flrb=36; vp=1532%2C946%2C2%2C0%2Ctag-photos-everyone-view%3A1226%2Csearch-photos-prints-view%3A886%2Csearch-photos-everyone-view%3A886; _sp_id.df80=20784097-d58b-4999-b8e5-734d21dab34b.1666181791.8.1679325014.1667110194.1dc51818-f960-4b25-ae5b-ad80c1d0dc66.87f8ed05-0901-495e-aa14-bdf2723d6190.e5e25774-f0d1-42d7-9fd4-e936d4a0445b.1679324485578.26",
}
def fetch_src_json(keyword, page=1, times=5) -> dict:
    """Fetch one page of flickr photo-search results as JSON.

    Retries on any failure (non-200, empty body, network/JSON error) with a
    decremented budget, returning {} once the budget is exhausted.

    Args:
        keyword: search text, e.g. "足球".
        page: 1-based result page to request.
        times: remaining retry budget; 0 means give up.

    Returns:
        The decoded JSON response dict, or {} after repeated failure.
    """
    # Brief pause so rapid retries do not hammer the API.
    time.sleep(0.02)
    if times == 0:
        return {}
    # BUG FIX: the endpoint URL had been garbled into "<https://...>".
    url = "https://api.flickr.com/services/rest"
    # Base parameters captured from a browser session; csrf/api_key should be
    # replaced with your own values (see the tutorial text above).
    params = {
        "sort": "relevance",
        "parse_tags": "1",
        "content_types": "0,1,2,3",
        "video_content_types": "0,1,2,3",
        "extras": "can_comment,can_print,count_comments,count_faves,description,isfavorite,license,media,needs_interstitial,owner_name,path_alias,realname,rotation,url_sq,url_q,url_t,url_s,url_n,url_w,url_m,url_z,url_c,url_l",
        "per_page": "50",
        "page": "2",
        "lang": "zh-HK",
        "text": "足球",
        "viewerNSID": "",
        "method": "flickr.photos.search",
        "csrf": "",
        "api_key": "5431b12ac9f1f58ad25ec5209ae3197a",
        "format": "json",
        "hermes": "1",
        "hermesClient": "1",
        "reqId": "d986c176-716c-463b-ae40-ea05e1575e64",
        "nojsoncallback": "1",
    }
    # Override the captured text/page/per_page with the caller's values.
    params.update(dict(text=f"{keyword}", page=f"{page}", per_page="75"))
    try:
        # Local HTTP proxy (BUG FIX: URLs were garbled with angle brackets).
        # Adjust or remove `proxies` if you can reach flickr directly.
        proxies = {
            "http://": "http://127.0.0.1:1087",
            "https://": "http://127.0.0.1:1087",
        }
        resp = httpx.get(url=url, headers=headers, params=params, verify=False, proxies=proxies)
        if resp.status_code == 200:
            result = resp.json()
            if not result:
                return fetch_src_json(keyword, page, times=times - 1)
            return result
        return fetch_src_json(keyword, page, times=times - 1)
    except Exception:
        # Network or JSON decode failure: retry with one fewer attempt.
        return fetch_src_json(keyword, page, times=times - 1)
def get_image_links(resp):
    """Extract the best-available CDN image URL for each photo in *resp*.

    Args:
        resp: JSON dict as returned by fetch_src_json; may be {} or None.

    Returns:
        List of non-empty URL strings; [] when resp is falsy or has no photos.
    """
    def get_link(d: dict):
        # Quality flags ordered best ("l", large) to worst ("sq", square
        # thumbnail); return the first one the photo record provides.
        image_qualities = ["l", "c", "z", "m", "w", "n", "s", "t", "q", "sq"]
        for quality_flag in image_qualities:
            quality = f"url_{quality_flag}_cdn"
            if quality in d:
                return d.get(quality)
        return ""

    if resp:
        # BUG FIX: the original filter `url != "" or url is not None` is true
        # for every value (including "" and None), so bad links were never
        # dropped. A plain truthiness test excludes both.
        photos = resp.get("photos", {}).get("photo", [])
        return [url for url in (get_link(photo) for photo in photos) if url]
    return []
def get_page_size(resp: dict) -> int:
    """Return the total number of result pages reported by a search response.

    Falls back to 0 when the response has no "photos"/"pages" information.
    """
    photos = resp.get("photos", {})
    return photos.get("pages", 0)
async def save_media(url, folder, times=3):
    """Download *url* and write it to dataset/image/<folder>/<filename>.

    Retries up to *times* times on network failure, then gives up silently
    (best-effort download; a missing image is not fatal to the crawl).

    Args:
        url: direct image URL; the final path segment is used as the filename.
        folder: sub-directory name under dataset/image (usually the keyword).
        times: remaining retry budget.
    """
    if times == 0:
        return
    filename = url.split("/")[-1]
    save_folder = pathlib.Path("dataset", "image", folder)
    async with httpx.AsyncClient(verify=False) as client:
        try:
            resp = await client.get(url=url)
        except Exception:
            return await save_media(url, folder, times=times - 1)
        if resp is None:
            return await save_media(url, folder, times=times - 1)
        save_folder.mkdir(parents=True, exist_ok=True)
        # BUG FIX: the original wrote every download to a garbled literal
        # path instead of using the computed (and otherwise unused) filename.
        async with aiofiles.open(f"{save_folder}/{filename}", mode="wb") as f:
            await f.write(resp.content)
def fetch_all_media(urls, folder):
    """Concurrently download every URL in *urls* into dataset/image/<folder>/.

    Blocks until all downloads finish. No-op for an empty URL list.
    """
    if not urls:
        return

    async def _download_all():
        # BUG FIX: asyncio.wait() no longer accepts bare coroutines
        # (deprecated in 3.8, removed in 3.11); gather them instead.
        await asyncio.gather(*(save_media(url, folder) for url in urls))

    asyncio.run(_download_all())
def fetch_by_keyword(keyword, qty):
    """Search flickr for *keyword* and download up to *qty* images.

    Walks the result pages, extracting image links from each page and
    downloading them, until qty images are saved or the pages run out.

    Args:
        keyword: search text; also used as the destination folder name.
        qty: approximate maximum number of images to download.
    """
    resp = fetch_src_json(keyword)
    if not resp:
        print(f"获取关键字为 {keyword} 的图片资源失败")
        # BUG FIX: the original fell through and kept working on the empty
        # response; nothing useful can be done without a first page.
        return
    pages = get_page_size(resp)
    count = 0
    # BUG FIX: range(1, pages) skipped the final result page.
    for page in tqdm(range(1, pages + 1)):
        if page != 1:
            # Page 1 was already fetched above.
            resp = fetch_src_json(keyword, page)
        image_links = get_image_links(resp)
        if not image_links:
            # One extra attempt for this page before skipping it.
            resp = fetch_src_json(keyword, page)
            image_links = get_image_links(resp)
        print(f"{keyword}已有{count}张,新保存{len(image_links)}张")
        if not image_links:
            continue
        count += len(image_links)
        fetch_all_media(image_links, keyword)
        if count >= qty:
            break
if __name__ == '__main__':
    # Fan one download task per keyword out over a small thread pool; the
    # with-block waits for every submitted task before exiting.
    keywords = ["足球"]
    with ThreadPoolExecutor(8) as pool:
        for kw in keywords:
            pool.submit(fetch_by_keyword, kw, 20000)
数据清理
在采集到的图片中,不一定是所有图片都有我们检测的目标的,我们需要对图片进行过滤,确保用来训练的图片中包含我们需要检测的目标
数据整理
在完成数据清理后,我们需要将图片移动到固定文件夹归档,并重命名我们的图片名称,以及统一图片格式。这里使用如下脚本操作。
def move_img_to_jpg(base_path: str, keyword: str):
    """Gather downloaded images into <base_path>/<keyword>/ with sequential names.

    Every .jpg/.png/.jpeg in the source directories is moved into the target
    directory and renamed to "<keyword>_<n><original suffix>"; the emptied
    source directories are then removed.

    Args:
        base_path: root directory holding the downloaded image folders.
        keyword: destination folder name and filename prefix, e.g. "football".
    """
    path_aim = Path(base_path, keyword)
    if not path_aim.exists():
        path_aim.mkdir()
    # Adjust the source directories to match your download layout:
    # aim_dirs = Path(base_path).glob(f"{keyword}_*")
    aim_dirs = [Path(base_path, "足球")]
    img_idx = 0
    for aim_dir in aim_dirs:
        # FIX(idiom): the original annotation `[PosixPath]` was a list
        # literal, not a valid type; use list[Path].
        image_paths: list[Path] = [*aim_dir.glob("*.jpg"), *aim_dir.glob("*.png"), *aim_dir.glob("*.jpeg")]
        for image in image_paths:
            img_idx += 1
            re_path = f"{keyword}_{img_idx}"
            print(f"rename {image.name} to {re_path}{image.suffix}")
            image.rename(Path(path_aim, f"{re_path}{image.suffix}"))
        # Raises if non-image files remain — acts as a safety net against
        # silently discarding data.
        aim_dir.rmdir()
if __name__ == '__main__':
    # Point base_path at the directory that holds the downloaded image folders.
    base_path = "path_to_your_download_image"
    move_img_to_jpg(base_path, "football")
数据标记
数据标记是指在我们整理后的图片中,标记出我们需要检测的目标。这里使用的标注工具是labelImg(通过 pip install labelImg 安装),标记界面如下,第一次打开之前,需要将打开目录指向图片目录,将存放目录指向将存放voc的目录。
将voc格式转成txt
由于labelImg使用的是VOC格式,即xml表示,需要将xml转换成Yolo识别的txt格式,使用如下脚本转换格式,若想直接使用下面脚本,请自行调整对应路径。
def cvt_voc2yolo(base_path: str, class_dict: dict):
    """Convert every VOC XML file under <base_path>/xml to a YOLO label file.

    Creates a sibling "labels" directory if missing and converts each .xml
    via _cvt_xml2yolo.

    Args:
        base_path: dataset root containing an "xml" sub-directory.
        class_dict: mapping from class name to YOLO class id, e.g. {'football': 0}.
    """
    path_xml = Path(base_path, "xml")
    path_label = Path(base_path, "labels")
    if not path_label.exists():
        path_label.mkdir()
    # FIX(idiom): the original used a list comprehension purely for its
    # side effects; a plain loop states the intent.
    for file_xml in path_xml.glob("*.xml"):
        _cvt_xml2yolo(file_xml, class_dict)
def _cvt_xml2yolo(path_xml: Path, classes: dict):
path_label = Path(path_xml.parent.parent, "labels", f"{path_xml.stem}.txt")
if not path_label.exists():
path_label.touch()
with open(str(path_label.resolve()), 'w') as label_file:
with open(str(path_xml.resolve()), "r", encoding='UTF-8') as xml_file:
tree = ET.parse(xml_file)
root = tree.getroot()
size = root.find('size')
size_width = int(size.find('width').text)
size_height = int(size.find('height').text)
for obj in root.iter('object'):
difficult = obj.find('difficult').text
cls = obj.find('name').text
if cls not in classes or int(difficult) == 1:
continue
cls_id = classes[cls]
xmlbox = obj.find('bndbox')
b = [float(xmlbox.find('xmin').text),
float(xmlbox.find('xmax').text),
float(xmlbox.find('ymin').text),
float(xmlbox.find('ymax').text)]
if size_width == 0 or size_height == 0 or b[0] == b[1] or b[2] == b[3]:
print("不合理的图不再给labels ", path_xml.stem)
label_file.close()
path_label.unlink()
return 1
# 标注越界修正
b[0] = max(0, b[0])
b[1] = min(size_width, b[1])
b[2] = max(0, b[2])
b[3] = min(size_height, b[3])
txt_data = [round(((b[0] + b[1]) / 2.0 - 1) / size_width, 6),
round(((b[2] + b[3]) / 2.0 - 1) / size_height, 6),
round((b[1] - b[0]) / size_width, 6),
round((b[3] - b[2]) / size_height, 6)]
if txt_data[0] < 0 or txt_data[1] < 0 or txt_data[2] < 0 or txt_data[3] < 0:
print("不合理的图不再给labels ", path_xml.stem)
label_file.close()
path_label.unlink()
return 1
label_file.write(str(cls_id) + " " + " ".join([str(a) for a in txt_data]) + '\\n')
return 0
if __name__ == '__main__':
    # Replace with your dataset root; maps the "football" class to id 0.
    base_path = "dataset_path"
    cvt_voc2yolo(base_path, {'football': 0})
数据划分
数据划分的意思是我们需要将一部分数据用做训练,一部分数据用做验证,一部分数据用做测试。这个根据自己的实际需求进行调整。这里划分后的目录结构如下(以train目录为例):
dataset/train
├── images
│ ├── football_1.jpg
│ ├── football_***.jpg
│ └── football_99.jpg
├── labels
│ ├── football_1.txt
│ ├── football_***.txt
│ └── football_99.txt
└── xml
├── football_1.xml
├── football_***.xml
└── football_99.xml
到此数据准备工作就完成了。
代码准备
下载yolov5项目,首次下载建议直接使用coco128数据集测试一下环境。
git clone https://github.com/ultralytics/yolov5.git
训练网络
给定数据集配置
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: /opt/***/dataset # dataset root dir
train: train/images # train images (relative to 'path') 128 images
val: train/images # 这里验证集给的跟训练集一致,根据自己实际数据集划分调整
test: # test images (optional)
# Classes
names:
0: football
# 这里我们只有一个类别,故写成该格式
使用训练数据集训练网络
python train.py --data path_to_data_config/football.yaml
训练完成后,权重文件会保存到runs/train/exp*/weights/下
测试推理
python detect.py --source pending_predict_path --weights best.pt --conf 0.4 --img 640
推理结果会保存到目录runs/detect/exp*下