windows训练FasterRCNN
网上很多代码跑FasterRCNN都需要linux环境,但是linux环境在显卡驱动上配置较为麻烦
克隆代码
代码分为了pytorch与tensorflow两种版本,并通过sh文件调用
git clone https://github.com/trzy/FasterRCNN.git
下载预训练模型(可在download_models.sh中查看)
http://trzy.org/files/fasterrcnn/vgg16_caffe.pth
官方训练方式
python -m pytorch.FasterRCNN --train --learning-rate=1e-3 --epochs=10 --load-from=vgg16_caffe.pth --save-best-to=results_1.pth
python -m pytorch.FasterRCNN --train --learning-rate=1e-4 --epochs=4 --load-from=results_1.pth --save-best-to=results_final.pth
修改源码
Windows下无法直接运行sh文件。而sh中执行的`python -m pytorch.FasterRCNN`本质上就是运行FasterRCNN包中的__main__.py文件,因此可以直接以pytorch内的FasterRCNN代码目录作为项目根路径,直接运行该文件
- 可以修改__main__.py重命名为main.py
- 必须将profile.py重命名为myprofile.py(否则运行时会报**AttributeError: module 'profile' has no attribute 'run'**错误,经检查是它与Python标准库中的profile模块重名导致的),同时在main中更新对应的导入
- 以pytorch/FasterRCNN作为根路径后,整个项目中的文件导入都需要重新配置
- 修改datasets/voc.py
def _get_classes(self):
    """Infer the class-index -> class-name mapping from ImageSets/Main.

    Every file named ``<class>_<split>.txt`` inside
    ``<self._dir>/ImageSets/Main`` contributes ``<class>`` (the text before
    the first underscore). Classes are numbered from 1 in sorted order and
    index 0 is reserved for "background".

    Raises:
        AssertionError: if no ``*_<self.split>.txt`` file exists.
    """
    main_dir = Path(os.path.join(self._dir, "ImageSets", "Main"))
    pattern = "*_" + self.split + ".txt"
    found = {os.path.basename(p).split("_")[0] for p in main_dir.glob(pattern)}
    assert len(found) > 0, "No classes found in ImageSets/Main for '%s' split" % self.split

    mapping = {}
    for idx, cls_name in enumerate(sorted(found), start=1):
        mapping[idx] = cls_name
    # Background is inserted last, after the real classes.
    mapping[0] = "background"
    return mapping
# 该函数的作用只是扫描ImageSets/Main下形如`<类别>_<split>.txt`的文件来生成class_index_to_name,并加入了一些检查;它只适用于VOC2007这类带有按类别划分文件的数据集,自己的数据集一般没有这些文件,因此修改为直接返回按自己类别手动填写的class_index_to_name:
def _get_classes(self):
    """Return the manually maintained class-index -> class-name mapping.

    Replaces the upstream directory-scanning implementation, which requires
    per-class ``<class>_<split>.txt`` files that custom datasets usually lack.

    NOTE(review): ``self.class_index_to_name`` presumably resolves to a
    hand-edited class-level dict on the dataset class (index 0 =
    "background"), because this method may run before the instance attribute
    is assigned. Confirm this, and edit that dict for your own classes.
    """
    # The original line read `return return ...`, which is a SyntaxError.
    return self.class_index_to_name
训练自己的数据集
已在VOC2028(安全帽数据集)和VOC2020(火焰烟雾数据集)上实验通过
https://github.com/njvisionpower/Safety-Helmet-Wearing-Dataset?tab=readme-ov-file
https://github.com/gengyanlei/fire-smoke-detect-yolov4/blob/master/readmes/README_ZN.md
由voc.py中的如下代码可知,该FasterRCNN实现要求标注中的depth为3(assert depth == 3),即不支持黑白(灰度)图
def _get_ground_truth_boxes(self, filepaths, allow_difficult):
gt_boxes_by_filepath = {}
for filepath in filepaths:
basename = os.path.splitext(os.path.basename(filepath))[0]
annotation_file = os.path.join(self._dir, "Annotations", basename) + ".xml"
tree = ET.parse(annotation_file)
root = tree.getroot()
assert tree != None, "Failed to parse %s" % annotation_file
assert len(root.findall("size")) == 1
size = root.find("size")
assert len(size.findall("depth")) == 1
depth = int(size.find("depth").text)
assert depth == 3
编写删除黑白图脚本
- 查询黑白图脚本
import os
import xml.etree.ElementTree as ET
def check_grayscale_images(folder_path):
    """Find VOC annotation files whose ``<size><depth>`` is not 3.

    The training code asserts ``depth == 3``, so these images must be removed.

    Args:
        folder_path: directory containing VOC ``.xml`` annotation files.
            Files without an ``.xml`` extension are ignored.

    Returns:
        list of ``(file_name, depth)`` tuples, sorted by file name, one for
        every annotation whose declared depth is not 3.

    Raises:
        xml.etree.ElementTree.ParseError: if an annotation is not valid XML.
        ValueError: if an annotation does not contain exactly one
            ``<size>`` with exactly one ``<depth>``, or the depth is not an int.
    """
    grayscale_images = []
    # Sorted for a reproducible listing. Non-XML files (backups, desktop.ini
    # on Windows, ...) are skipped because ET.parse would fail on them.
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.lower().endswith(".xml"):
            continue
        annotation_file = os.path.join(folder_path, file_name)
        # ET.parse raises ParseError itself; it never returns None.
        root = ET.parse(annotation_file).getroot()
        sizes = root.findall("size")
        if len(sizes) != 1:
            raise ValueError("Expected exactly one <size> in %s" % annotation_file)
        depths = sizes[0].findall("depth")
        if len(depths) != 1:
            raise ValueError("Expected exactly one <size><depth> in %s" % annotation_file)
        depth = int(depths[0].text)
        if depth != 3:
            # Record the real depth (previously a constant 1 was stored).
            grayscale_images.append((file_name, depth))
    return grayscale_images
if __name__ == "__main__":
    # Folder holding the VOC annotation XML files to scan.
    folder_path = "../datasets/VOC2028/Annotations"
    # Print the offending image ids as a Python list literal so the output can
    # be pasted directly into the deletion script's `ll`. Each depth goes into
    # a trailing comment so the literal stays valid.
    grayscale_images = check_grayscale_images(folder_path)
    print("[")
    for image_name, depth in grayscale_images:
        # splitext instead of [:-4]: correct for any extension length.
        print(f"'{os.path.splitext(image_name)[0]}',  # depth={depth}")
    print("]")
- 删除黑白图脚本
import os
# Image ids to drop (paste the list printed by the grayscale-check script).
ll = [
    '000318',
    '000324',
]
PATH = "../datasets/VOC2028/ImageSets/Main/"

# Remove the ids from every split list. The loop replaces four copy-pasted
# read/filter/write sections, and the set makes each membership test O(1).
_ids_to_remove = set(ll)
for _split in ("test", "train", "trainval", "val"):
    _split_file = os.path.join(PATH, _split + ".txt")
    with open(_split_file, 'r', encoding='utf-8') as f:
        data = f.read().splitlines()
    # strip() so ids with stray surrounding whitespace still match.
    new_data = [line for line in data if line.strip() not in _ids_to_remove]
    with open(_split_file, 'w', encoding='utf-8') as f:
        # Keep the trailing newline that '\n'.join alone would drop.
        f.write('\n'.join(new_data) + ('\n' if new_data else ''))

# Folders holding the annotation XMLs and images to delete.
Annotations = "../datasets/VOC2028/Annotations"
JPEGImages = "../datasets/VOC2028/JPEGImages"
def delete_files_in_folder(folder_path, file_ext, file_names=None):
    """Delete ``<folder_path>/<name>.<file_ext>`` for every listed name.

    Args:
        folder_path: directory to delete from.
        file_ext: extension without the leading dot, e.g. 'xml' or 'jpg'.
        file_names: file stems (image ids) to delete. Defaults to the
            module-level ``ll`` list so existing two-argument calls still work.

    Missing files are reported and skipped rather than treated as errors.
    """
    names = ll if file_names is None else file_names
    for file_name in names:
        file_path = os.path.join(folder_path, file_name + '.' + file_ext)
        # EAFP: attempt the removal directly instead of exists()-then-remove().
        try:
            os.remove(file_path)
        except FileNotFoundError:
            print(f"File does not exist: {file_path}")
        else:
            print(f"Deleted file: {file_path}")
# Delete the listed ids' annotation (.xml) files first, then their images (.jpg).
for _folder, _ext in ((Annotations, 'xml'), (JPEGImages, 'jpg')):
    delete_files_in_folder(_folder, _ext)
目前不清楚FasterRCNN是否支持没有有效标注的图片(即标注中没有object,或所有object都被标记为difficult;SSD并不支持),这里附带检查此类图片的程序(只打印出有问题的图片,不会删除)
import argparse
import sys
import cv2
import os
import os.path as osp
import numpy as np
if sys.version_info[0] == 2:
import xml.etree.cElementTree as ET
else:
import xml.etree.ElementTree as ET
# Command-line options. The description previously said "Single Shot MultiBox
# Detector Training", copied from the SSD training script; this is a checker.
parser = argparse.ArgumentParser(
    description='Check VOC annotations for images without usable objects')
parser.add_argument('--root', default='../datasets/VOC2020/',
                    help='Dataset root directory path')
args = parser.parse_args()

# Class names in label-index order; every annotated <name> must appear here.
CLASSES = ['fire']

# Path templates, filled as `annopath % (root, image_id)`.
annopath = osp.join('%s', 'Annotations', '%s.xml')
imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
def vocChecker(image_id, width, height, keep_difficult=False):
    """Parse one VOC annotation into normalized boxes and report empty ones.

    Args:
        image_id: ``(root, id)`` tuple substituted into ``annopath``.
        width: image width, used to normalize x coordinates.
        height: image height, used to normalize y coordinates.
        keep_difficult: if False, objects with ``<difficult>1</difficult>``
            are skipped.

    Returns:
        list of ``[xmin, ymin, xmax, ymax, label_idx]``. Coordinates are
        shifted by -1 (1-based VOC pixels) and divided by width/height.

    Prints ``image_id`` when no usable object remains. Such images break
    training code that indexes ``targets[:, 4]`` (e.g. SSD). Raises KeyError
    for an object whose name is not in ``CLASSES``.
    """
    target = ET.parse(annopath % image_id).getroot()
    # Built once per call instead of once per object.
    label_map = {cls_name: idx for idx, cls_name in enumerate(CLASSES)}
    res = []
    for obj in target.iter('object'):
        # Some annotation tools omit <difficult>; treat that as not difficult.
        difficult_node = obj.find('difficult')
        difficult = difficult_node is not None and int(difficult_node.text) == 1
        if not keep_difficult and difficult:
            continue
        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')
        bndbox = []
        for i, pt in enumerate(('xmin', 'ymin', 'xmax', 'ymax')):
            # float() first: some tools write coordinates like "123.0".
            cur_pt = int(float(bbox.find(pt).text)) - 1
            # Even index -> x (scale by width), odd index -> y (scale by height).
            bndbox.append(float(cur_pt) / width if i % 2 == 0 else float(cur_pt) / height)
        bndbox.append(label_map[name])
        res.append(bndbox)  # [xmin, ymin, xmax, ymax, label_ind]
    if not res:
        # Replaces the old np.array(res)[:, 4] probe. That probe only raised
        # IndexError when res was empty (a 1-D array).
        print(image_id, " had error index")
    return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]
if __name__ == '__main__':
    # One annotation XML per image; ignore anything else in the folder.
    ann_dir = osp.join(args.root, 'Annotations')
    names = sorted(n for n in os.listdir(ann_dir) if n.lower().endswith('.xml'))
    for name in names:
        # splitext keeps ids that themselves contain dots intact.
        image_id = osp.splitext(name)[0]
        img = cv2.imread(imgpath % (args.root, image_id))
        if img is None:  # cv2.imread returns None instead of raising
            print(image_id, " image missing or unreadable")
            continue
        height, width = img.shape[:2]
        # vocChecker's signature is (image_id, width, height); the original
        # call passed height and width swapped.
        vocChecker((args.root, image_id), width, height)
    print("Total of annotations : {}".format(len(names)))
可以直接clone我的代码
git clone https://github.com/vivid-boy/FasterRCNN.git