windows训练FasterRCNN


网上很多代码跑FasterRCNN都需要linux环境,但是linux环境在显卡驱动上配置较为麻烦

克隆代码

代码分为了pytorch与tensorflow两种版本,并通过sh文件调用

git clone https://github.com/trzy/FasterRCNN.git

下载预训练模型(可在download_models.sh中查看)
http://trzy.org/files/fasterrcnn/vgg16_caffe.pth
官方训练方式

python -m pytorch.FasterRCNN --train --learning-rate=1e-3 --epochs=10 --load-from=vgg16_caffe.pth --save-best-to=results_1.pth
python -m pytorch.FasterRCNN --train --learning-rate=1e-4 --epochs=4 --load-from=results_1.pth --save-best-to=results_final.pth

修改源码

我们可以直接以pytorch内的FasterRCNN代码作为项目根路径,运行sh的原理也相当于是运行__main__.py文件

  1. 可以修改__main__.py重命名为main.py
  2. 必须修改profile.py文件为myprofile.py(因为运行代码会报**AttributeError: module ‘profile’ has no attribute ‘run’**错误,最后检查为profile重名问题
    ),同时在main中更新导入
  3. 以pytoch作为根路径整个项目需要重新配置文件导入
  4. 修改datasets/voc.py
def _get_classes(self):
    imageset_dir = os.path.join(self._dir, "ImageSets", "Main")
    classes = set([ os.path.basename(path).split("_")[0] for path in Path(imageset_dir).glob("*_" + self.split + ".txt") ])
    assert len(classes) > 0, "No classes found in ImageSets/Main for '%s' split" % self.split
    class_index_to_name = { (1 + v[0]): v[1] for v in enumerate(sorted(classes)) }
    class_index_to_name[0] = "background"
    return class_index_to_name
# 因为该函数实现功能就是获取class_index_to_name,只不过加入了一些异常处理,但是只使用与VOC2007数据集,修改为
def _get_classes(self):
	return return self.class_index_to_name

训练自己的数据集

已在VOC2028和VOC2020实验通过
https://github.com/njvisionpower/Safety-Helmet-Wearing-Dataset?tab=readme-ov-file
https://github.com/gengyanlei/fire-smoke-detect-yolov4/blob/master/readmes/README_ZN.md
根据voc.py中如下代码得知FastRCNN不支持黑白图

  def _get_ground_truth_boxes(self, filepaths, allow_difficult):
    gt_boxes_by_filepath = {}
    for filepath in filepaths:
      basename = os.path.splitext(os.path.basename(filepath))[0]
      annotation_file = os.path.join(self._dir, "Annotations", basename) + ".xml"
      tree = ET.parse(annotation_file)
      root = tree.getroot()
      assert tree != None, "Failed to parse %s" % annotation_file
      assert len(root.findall("size")) == 1
      size = root.find("size")
      assert len(size.findall("depth")) == 1
      depth = int(size.find("depth").text)
      assert depth == 3

编写删除黑白图脚本

  1. 查询黑白图脚本
import os
import xml.etree.ElementTree as ET


def check_grayscale_images(folder_path):
    grayscale_images = []

    # 遍历文件夹下的所有文件
    for file_name in os.listdir(folder_path):
        annotation_file = os.path.join(folder_path, file_name)
        tree = ET.parse(annotation_file)
        root = tree.getroot()
        assert tree != None, "Failed to parse %s" % annotation_file
        assert len(root.findall("size")) == 1
        size = root.find("size")
        assert len(size.findall("depth")) == 1
        depth = int(size.find("depth").text)
        if depth != 3:
            grayscale_images.append((file_name, 1))

    return grayscale_images


# 指定文件夹路径
folder_path = "../datasets/VOC2028/Annotations"

# 检查灰度图像并输出文件名和深度
grayscale_images = check_grayscale_images(folder_path)
print("[")
for image_name, depth in grayscale_images:
    print(f"'{image_name[:-4]}',")
print("]")
  1. 删除黑白图脚本
import os

ll = [
'000318',
'000324',
]

PATH = "../datasets/VOC2028/ImageSets/Main/"

with open(PATH + "test.txt", 'r') as f:
    data = f.read().splitlines()
new_data = [line for line in data if line not in ll]
with open(PATH + "test.txt", 'w') as f:
    f.write('\n'.join(new_data))

with open(PATH + "train.txt", 'r') as f:
    data = f.read().splitlines()
new_data = [line for line in data if line not in ll]
with open(PATH + "train.txt", 'w') as f:
    f.write('\n'.join(new_data))

with open(PATH + "trainval.txt", 'r') as f:
    data = f.read().splitlines()
new_data = [line for line in data if line not in ll]
with open(PATH + "trainval.txt", 'w') as f:
    f.write('\n'.join(new_data))

with open(PATH + "val.txt", 'r') as f:
    data = f.read().splitlines()
new_data = [line for line in data if line not in ll]
with open(PATH + "val.txt", 'w') as f:
    f.write('\n'.join(new_data))

Annotations = "../datasets/VOC2028/Annotations"
JPEGImages = "../datasets/VOC2028/JPEGImages"


def delete_files_in_folder(folder_path, file_ext):
    # 遍历列表中的文件名
    for file_name in ll:
        file_path = os.path.join(folder_path, file_name + '.' + file_ext)
        # 检查文件是否存在
        if os.path.exists(file_path):
            # 删除文件
            os.remove(file_path)
            print(f"Deleted file: {file_path}")
        else:
            print(f"File does not exist: {file_path}")


# 删除文件夹中指定的文件
delete_files_in_folder(Annotations, 'xml')
delete_files_in_folder(JPEGImages, 'jpg')

目前不清楚FastRCNN是否支持单标签图片(SSD并不支持),这里附带去除单标签图片程序

import argparse
import sys
import cv2
import os
import os.path as osp
import numpy as np

if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET

parser = argparse.ArgumentParser(
    description='Single Shot MultiBox Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()
parser.add_argument('--root', default='../datasets/VOC2020/', help='Dataset root directory path')
args = parser.parse_args()

CLASSES = ['fire']
annopath = osp.join('%s', 'Annotations', '%s.{}'.format("xml"))
imgpath = osp.join('%s', 'JPEGImages', '%s.{}'.format("jpg"))


def vocChecker(image_id, width, height, keep_difficult=False):
    target = ET.parse(annopath % image_id).getroot()
    res = []
    for obj in target.iter('object'):
        difficult = int(obj.find('difficult').text) == 1
        if not keep_difficult and difficult:
            continue
        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')
        pts = ['xmin', 'ymin', 'xmax', 'ymax']
        bndbox = []
        for i, pt in enumerate(pts):
            cur_pt = int(bbox.find(pt).text) - 1
            # scale height or width
            cur_pt = float(cur_pt) / width if i % 2 == 0 else float(cur_pt) / height
            bndbox.append(cur_pt)
        label_idx = dict(zip(CLASSES, range(len(CLASSES))))[name]
        bndbox.append(label_idx)
        res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]
        # img_id = target.find('filename').text[:-4]
    try:
        np.array(res)[:, 4]
        np.array(res)[:, :4]
    except IndexError:
        print(image_id, " had error index")

    return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]


if __name__ == '__main__':
    i = 0
    for name in sorted(os.listdir(osp.join(args.root, 'Annotations'))):
        # as we have only one annotations file per image
        i += 1
        img = cv2.imread(imgpath % (args.root, name.split('.')[0]))
        height, width, channels = img.shape
        res = vocChecker((args.root, name.split('.')[0]), height, width)
    print("Total of annotations : {}".format(i))

可以直接clone我的代码

git clone https://github.com/vivid-boy/FasterRCNN.git
  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值