Yolo v4--pytorch版本多图测试

Pytorch版 Yolo v4:
https://github.com/Tianxiaomo/pytorch-YOLOv4

COCO数据集预训练模型

yolov4.pth(https://pan.baidu.com/s/1ZroDvoGScDgtE1ja_QqJVw Extraction code:xrq9)
yolov4.conv.137.pth(https://pan.baidu.com/s/1ovBie4YyVQQoUrC3AY0joA Extraction code:kcel)

配置环境

conda create -n yolov4
安装requirements.txt依赖库

模型推断

python models.py <num_classes> <weightfile> <imgfile> <IN_IMAGE_H> <IN_IMAGE_W> <namefile(optional)>

修改代码,多图片预测

1 准备图片路径名

文本保存为train.txt,每行一个文件名

2 修改测试代码
class Yolov4(nn.Module):
    def __init__(self, yolov4conv137weight=None, n_classes=80, inference=False):
        super().__init__()

        output_ch = (4 + 1 + n_classes) * 3

        # backbone
        self.down1 = DownSample1()
        self.down2 = DownSample2()
        self.down3 = DownSample3()
        self.down4 = DownSample4()
        self.down5 = DownSample5()
        # neck
        self.neek = Neck(inference)
        # yolov4conv137
        if yolov4conv137weight:
            _model = nn.Sequential(self.down1, self.down2, self.down3, self.down4, self.down5, self.neek)
            pretrained_dict = torch.load(yolov4conv137weight)

            model_dict = _model.state_dict()
            # 1. filter out unnecessary keys
            pretrained_dict = {k1: v for (k, v), k1 in zip(pretrained_dict.items(), model_dict)}
            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)
            _model.load_state_dict(model_dict)
        
        # head
        self.head = Yolov4Head(output_ch, n_classes, inference)


    def forward(self, input):
        d1 = self.down1(input)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)

        x20, x13, x6 = self.neek(d5, d4, d3)

        output = self.head(x20, x13, x6)
        return output


if __name__ == "__main__":
    import sys
    import cv2
    import pandas as pd
    from tqdm import tqdm
    import pickle

    weightfile = 'ckpt/yolov4.pth'
    train_data_path = 'data/small_data_list/small_data_0.2_test_data_list.txt'
    root_dir = '/home/flyingbird/Data/images'
    conv137weight = 'ckpt/yolov4.conv.137.pth'

    n_classes = 80
    width = 512
    height = 512

    all_box_dict = {}

    model = Yolov4(yolov4conv137weight=conv137weight, n_classes=n_classes, inference=True)

    pretrained_dict = torch.load(weightfile, map_location=torch.device('cuda'))
    model.load_state_dict(pretrained_dict)

    use_cuda = True
    if use_cuda:
        model.cuda()

    test_data_list = list(pd.read_csv(train_data_path, index_col=False, header=None, sep=' ')[0])

    for i in tqdm(test_data_list):
        img_path = os.path.join(root_dir, str(i)) + '.jpg'
        img = cv2.imread(img_path)

        # Inference input size is 416*416 does not mean training size is the same
        # Training size could be 608*608 or even other sizes
        # Optional inference sizes:
        #   Hight in {320, 416, 512, 608, ... 320 + 96 * n}
        #   Width in {320, 416, 512, 608, ... 320 + 96 * m}
        sized = cv2.resize(img, (width, height))
        sized = cv2.cvtColor(sized, cv2.COLOR_BGR2RGB)

        from tool.utils import load_class_names, plot_boxes_cv2
        from tool.torch_utils import do_detect

        for num in range(2):
            # This 'for' loop is for speed check
            # Because the first iteration is usually longer
            boxes = do_detect(model, sized, 0.00001, 0.1, use_cuda)

        # all_box_dict[i] = boxes[0]
        # print(len(boxes[0]))

        namesfile = 'data/coco.names'
        class_names = load_class_names(namesfile)
        save_path = os.path.join('predict/test', str(i)) + '.jpg'
        img, boxes_list = plot_boxes_cv2(img, boxes[0], save_path, class_names)
        # print('red one')
        all_box_dict[i] = boxes_list

    with open('predict/test_pred_box_dict_conf_0.00001_nms_0.1.txt', 'wb') as f:
        pickle.dump(all_box_dict, f)

3 结果输出保存

前面网络结构部分不用修改,后面测试时读取每张图,检测并输出box和置信度,筛选得到最终预测结果并保存在save_path中,预测结果的box用pickle保存为txt文件.

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值