yolov1

论文: https://arxiv.org/abs/1506.02640

target的格式: 7 * 7 * 30。每个格子的前20个值是类别,然后是

[box1_confidence, x, y, w, h, box2_confidence, x, y, w, h]

记得先把输入图片resize到 448 * 448 再送入网络。

backbone

import torch
import torch.nn as nn

# Backbone layout consumed by Yolo_v1._create_cnn (24 conv layers in total).
# Entry formats:
#   (kernel_size, out_channels, stride, padding) -> one CNN_BLOCK
#   "M"                                          -> MaxPool2d(kernel_size=2, stride=2)
#   [conv_a, conv_b, repeats]                    -> conv_a then conv_b, repeated `repeats` times
# For a 448x448 input the two stride-2 convs and four max-pools shrink the
# feature map 448 -> 7, matching the FC head's 1024 * S * S input when S=7.
architecture_config = [
    (7, 64, 2, 3),  # 448 -> 224
    "M",  # -> 112
    (3, 192, 1, 1),
    "M",  # -> 56
    (1, 128, 1, 0),
    (3, 256, 1, 1),
    (1, 256, 1, 0),
    (3, 512, 1, 1),
    "M",  # -> 28
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0),
    (3, 1024, 1, 1),
    "M",  # -> 14
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1),
    (3, 1024, 2, 1),  # -> 7
    (3, 1024, 1, 1),
    (3, 1024, 1, 1),
]


class CNN_BLOCK(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.1).

    Extra keyword arguments (kernel_size, stride, padding, ...) are forwarded
    to ``nn.Conv2d``. The conv has no bias because the BatchNorm that follows
    applies its own learnable shift.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        features = self.cnn(x)
        normalized = self.bn(features)
        return self.relu(normalized)


class Yolo_v1(nn.Module):
    """YOLOv1 network: conv backbone built from ``architecture_config`` + FC head.

    Args:
        S: grid size (the FC head assumes the backbone emits 1024 x S x S maps).
        B: boxes predicted per grid cell.
        C: number of classes.

    ``forward`` returns a flat (N, S * S * (C + B * 5)) tensor.
    """

    def __init__(self, S=7, B=2, C=20):
        super().__init__()
        self.S = S
        self.B = B
        self.C = C
        self.cnn = self._create_cnn()
        self.fc = self._create_fc()

    def _create_cnn(self):
        """Translate ``architecture_config`` into an ``nn.Sequential``.

        Entry kinds:
          * str (e.g. "M")                        -> MaxPool2d(2, stride=2)
          * tuple (kernel, out_ch, stride, pad)   -> one CNN_BLOCK
          * list [conv_a, conv_b, repeats]        -> (conv_a, conv_b) x repeats
        """
        channels = 3  # the backbone takes 3-channel input images

        def conv(spec):
            # Build a CNN_BLOCK from a (kernel, out_ch, stride, pad) tuple and
            # advance the running channel count to its output width.
            nonlocal channels
            block = CNN_BLOCK(
                in_channels=channels,
                out_channels=spec[1],
                kernel_size=spec[0],
                stride=spec[2],
                padding=spec[3],
            )
            channels = spec[1]
            return block

        modules = []
        for entry in architecture_config:
            if isinstance(entry, str):
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif isinstance(entry, tuple):
                modules.append(conv(entry))
            else:
                conv_a, conv_b, repeats = entry[0], entry[1], entry[-1]
                for _ in range(repeats):
                    modules.append(conv(conv_a))
                    modules.append(conv(conv_b))
        return nn.Sequential(*modules)

    def _create_fc(self):
        """FC head: flatten -> 4096 -> LeakyReLU(0.1) -> S*S*(C + B*5) outputs."""
        in_features = 1024 * self.S * self.S
        out_features = self.S * self.S * (self.C + self.B * 5)
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 4096),
            nn.LeakyReLU(0.1),
            nn.Linear(4096, out_features),
        )

    def forward(self, x):
        """Map a batch of images to flat (N, S * S * (C + B * 5)) predictions."""
        features = self.cnn(x)
        return self.fc(features)


输入 batch * 3 * 448 * 448,输出是展平后的 batch * (7 * 7 * 30) 向量,使用时 reshape 成 batch * 7 * 7 * 30

loss 按照论文中的公式实现:

import torch
import torch.nn as nn

from utils import intersection_over_union


class Yolo_v1_loss(nn.Module):
    """YOLOv1 sum-squared-error detection loss (summed, not averaged, over the batch).

    Per-cell channel layout (last dimension) assumed by the hard-coded slices
    below, i.e. C=20 classes and B=2 boxes:
        [0:20]   class scores
        [20]     box1 confidence      [21:25] box1 x, y, w, h
        [25]     box2 confidence      [26:30] box2 x, y, w, h
    In ``target``, channel 20 is expected to be 1 for cells that contain an
    object (0 otherwise) and channels 21:25 hold that object's box.

    NOTE(review): the slices ignore S/B/C, so only C=20, B=2 is supported.
    NOTE(review): in a cell that contains an object, the non-responsible
    predictor gets no confidence loss at all; the paper's lambda_noobj term is
    often read as also covering it -- confirm this is intended.
    """

    def __init__(self, S=7, B=2, C=20):
        super(Yolo_v1_loss, self).__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.S = S
        self.B = B
        self.C = C

        self.coord = 5    # weight of the box-coordinate loss
        self.noobj = 0.5  # weight of the confidence loss in empty cells

    def forward(self, predict, target):
        """Compute the total loss.

        Args:
            predict: network output, reshapeable to (N, S, S, B * 5 + C),
                e.g. (N, S * S * (B * 5 + C)).
            target: (N, S, S, B * 5 + C) ground truth in the layout above.

        Returns:
            Scalar tensor: coord * box_loss + object confidence loss
            + noobj * no-object confidence loss + class loss.
        """
        predict = predict.reshape(-1, self.S, self.S, self.B * 5 + self.C)

        # IoU of each cell's two predicted boxes with that cell's target box.
        iou_b1 = intersection_over_union(predict[..., 21:25], target[..., 21:25])
        iou_b2 = intersection_over_union(predict[..., 26:30], target[..., 21:25])
        # best_box_idx is 1 where box2 overlaps the target better, else 0; the
        # chosen predictor is the one "responsible" for the object.
        _, best_box_idx = torch.max(torch.stack([iou_b1, iou_b2], dim=0), dim=0)

        exists_box = target[..., 20:21]  # per-cell object mask

        # ---- box coordinate loss (responsible predictor, object cells only) ----
        box_predict = exists_box * (
            best_box_idx * predict[..., 26:30] + (1 - best_box_idx) * predict[..., 21:25]
        )
        box_target = exists_box * target[..., 21:25]
        # Regress sqrt(w), sqrt(h) as in the paper. Predictions may be negative,
        # so keep their sign; adding the epsilon *after* abs keeps the sqrt
        # argument strictly positive, so its gradient stays finite.
        # torch.cat is used instead of slice assignment so no tensor in the
        # autograd graph is modified in place.
        box_predict = torch.cat(
            [
                box_predict[..., 0:2],
                torch.sign(box_predict[..., 2:4])
                * torch.sqrt(torch.abs(box_predict[..., 2:4]) + 1e-6),
            ],
            dim=-1,
        )
        box_target = torch.cat(
            [box_target[..., 0:2], torch.sqrt(box_target[..., 2:4])], dim=-1
        )
        box_loss = self.mse(
            torch.flatten(box_predict, end_dim=-2),
            torch.flatten(box_target, end_dim=-2),
        )

        # ---- confidence loss ----
        # Confidence of the responsible predictor, compared with the target
        # confidence in cells that contain an object.
        confidence_predict = (
            best_box_idx * predict[..., 25:26] + (1 - best_box_idx) * predict[..., 20:21]
        )
        confidence_target = target[..., 20:21]
        confidence_loss = self.mse(
            torch.flatten(exists_box * confidence_predict),
            torch.flatten(exists_box * confidence_target),
        )
        # Empty cells vastly outnumber object cells, so their confidence loss
        # (both predictors) is down-weighted by self.noobj below.
        no_confidence_loss = self.mse(
            torch.flatten((1 - exists_box) * predict[..., 20:21]),
            torch.flatten((1 - exists_box) * confidence_target),
        )
        no_confidence_loss += self.mse(
            torch.flatten((1 - exists_box) * predict[..., 25:26]),
            torch.flatten((1 - exists_box) * confidence_target),
        )

        # ---- class loss (object cells only) ----
        class_loss = self.mse(
            torch.flatten(exists_box * predict[..., :20], end_dim=-2),
            torch.flatten(exists_box * target[..., :20], end_dim=-2),
        )

        return (
            self.coord * box_loss
            + confidence_loss
            + self.noobj * no_confidence_loss
            + class_loss
        )

utils

import xml.etree.ElementTree as ET
import os
import os.path
import numpy as np
import torch
import matplotlib.pyplot as plt  # 导入绘图包
import cv2 as cv

# The 20 PASCAL VOC object classes; a name's position in this list is the
# class id written to label files and used when plotting predictions.
class_list = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
              'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
              'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
# Class name -> class id, derived from class_list so the two cannot drift apart.
class_dict = {name: i for i, name in enumerate(class_list)}


def parse_xml(xml_path='../../VOCtrainval_06-Nov-2007/VOCdevkit/VOC2007/Annotations/',
              out_dir='labels'):
    """Convert PASCAL VOC XML annotations into YOLO-style text label files.

    For every ``*.xml`` file in ``xml_path`` a file ``<out_dir>/<stem>.txt`` is
    written with one line per object:
        class_id x_center y_center w h
    where the box centre and size are normalized by the image width/height.

    Args:
        xml_path: directory holding the VOC annotation XML files.
        out_dir: directory the label files are written to (created if missing).

    Raises:
        KeyError: if an object's class name is not in class_dict.
    """
    xml_files = os.listdir(xml_path)
    os.makedirs(out_dir, exist_ok=True)
    for file in xml_files:
        if not file.lower().endswith('.xml'):
            continue  # skip stray non-annotation files
        # Parse fully before opening the output, so a malformed XML or an
        # unknown class does not leave an empty/partial label file behind.
        root = ET.parse(os.path.join(xml_path, file)).getroot()
        width = float(root.find('size/width').text)
        height = float(root.find('size/height').text)

        lines = []
        for obj in root.findall('object'):
            c = class_dict[obj.find('name').text.strip()]
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)

            x_center = (xmin + xmax) / (2 * width)
            y_center = (ymin + ymax) / (2 * height)
            w = (xmax - xmin) / width
            h = (ymax - ymin) / height
            lines.append(' '.join([str(c), str(x_center), str(y_center), str(w), str(h)]) + '\n')

        out_file = os.path.join(out_dir, os.path.splitext(file)[0] + '.txt')
        with open(out_file, 'w') as f:
            f.writelines(lines)


def intersection_over_union(box1, box2, mode='center'):
    """Element-wise IoU between two (broadcastable) tensors of boxes.

    Args:
        box1, box2: tensors whose last dimension holds 4 box values.
        mode: 'center' -> boxes are (x_center, y_center, w, h);
            any other value -> boxes are corners (xmin, ymin, xmax, ymax).

    Returns:
        Tensor shaped like ``box1[..., 0:1]`` (after broadcasting) holding the
        IoU of each pair; a 1e-6 term in the denominator avoids division by 0.
    """
    if mode == 'center':
        # Convert (x_center, y_center, w, h) to corner coordinates.
        box1_x1 = box1[..., 0:1] - box1[..., 2:3] / 2
        box1_y1 = box1[..., 1:2] - box1[..., 3:4] / 2
        box1_x2 = box1[..., 0:1] + box1[..., 2:3] / 2
        box1_y2 = box1[..., 1:2] + box1[..., 3:4] / 2

        box2_x1 = box2[..., 0:1] - box2[..., 2:3] / 2
        box2_y1 = box2[..., 1:2] - box2[..., 3:4] / 2
        box2_x2 = box2[..., 0:1] + box2[..., 2:3] / 2
        box2_y2 = box2[..., 1:2] + box2[..., 3:4] / 2
    else:
        # Corner format: (xmin, ymin, xmax, ymax) stored in slots 0..3.
        box1_x1 = box1[..., 0:1]
        box1_y1 = box1[..., 1:2]
        box1_x2 = box1[..., 2:3]
        box1_y2 = box1[..., 3:4]

        box2_x1 = box2[..., 0:1]
        box2_y1 = box2[..., 1:2]
        box2_x2 = box2[..., 2:3]
        box2_y2 = box2[..., 3:4]

    # Intersection rectangle; clamp(0) gives zero area for disjoint boxes.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)
    intersection_area = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    # Union = area1 + area2 - intersection.
    box1_area = ((box1_x2 - box1_x1) * (box1_y2 - box1_y1)).abs()
    box2_area = ((box2_x2 - box2_x1) * (box2_y2 - box2_y1)).abs()

    return intersection_area / (box1_area + box2_area - intersection_area + 1e-6)


def non_max_suppression(bboxes, iou_threshold=0.5, threshold=0.4):
    """Per-class greedy non-maximum suppression.

    Args:
        bboxes: iterable of rows [class, confidence, x, y, w, h] (lists or 1-D
            tensors), with x, y, w, h in center format.
        iou_threshold: a box is dropped when its IoU with an already kept,
            higher-confidence box of the same class is >= this value.
        threshold: boxes with confidence <= this value are discarded up front.

    Returns:
        List of the kept rows, in descending confidence order.
    """
    candidates = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )
    kept = []
    while candidates:
        chosen = candidates.pop(0)
        # as_tensor avoids copying (and warning) when rows are already tensors.
        chosen_xywh = torch.as_tensor(chosen[2:6])
        # A box survives if it predicts a different class, or if it overlaps
        # the chosen box less than iou_threshold (i.e. a different object).
        candidates = [
            box for box in candidates
            if box[0] != chosen[0]
            or intersection_over_union(chosen_xywh, torch.as_tensor(box[2:6])) < iou_threshold
        ]
        kept.append(chosen)
    return kept


def plot_box(boxes, img):
    """Draw labelled boxes on ``img`` with matplotlib and show the figure.

    Args:
        boxes: iterable of rows [class, confidence, x, y, w, h] (tensors or
            plain numbers); (x, y) is the box centre and (w, h) its size, all
            normalized to the image.
        img: image array whose first two dimensions are (H, W).
    """
    H, W = img.shape[:2]
    plt.imshow(img)
    current_axis = plt.gca()
    for bbox in boxes:
        # float() accepts both 0-d tensors and plain numbers (.item() does not).
        classes, confidence, x, y, w, h = (float(v) for v in bbox[:6])
        confidence = round(confidence, 2)
        # Convert normalized centre/size to pixel corners.
        xmin = (x - w / 2) * W
        xmax = (x + w / 2) * W
        ymin = (y - h / 2) * H
        ymax = (y + h / 2) * H
        current_axis.add_patch(
            plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, color='green', fill=False, linewidth=2))
        current_axis.text(xmin, ymin, class_list[int(classes)] + ': {}'.format(confidence),
                          color='white', bbox={'facecolor': 'green', 'alpha': 1.0})
    plt.show()


def get_boxes(pre, S=7):
    """Decode a single-image YOLOv1 prediction into boxes and apply NMS.

    Args:
        pre: prediction tensor with S * S * 30 elements, shaped either
            (1, S, S, 30) or flat (e.g. (1, S * S * 30)). Per-cell layout:
            20 class scores, then box1 [confidence, x, y, w, h], then box2
            [confidence, x, y, w, h]; x, y are offsets inside the cell and
            w, h are in cell units.
        S: grid size.

    Returns:
        ``non_max_suppression`` output: rows [class, confidence, x, y, w, h]
        with the box centre and size normalized to the whole image.
    """
    # Work on a detached copy so the caller's tensor is never modified
    # (decoding in place would corrupt it for any later use or second call).
    pre = pre.detach().reshape(S, S, 30).clone()

    offsets = torch.arange(S, device=pre.device, dtype=pre.dtype)
    col = offsets.view(1, S, 1)  # column index j of cell (i, j) -> x offset
    row = offsets.view(S, 1, 1)  # row index i of cell (i, j)    -> y offset
    for start in (21, 26):  # x, y, w, h of box1 and box2
        pre[..., start:start + 1] = (pre[..., start:start + 1] + col) / S
        pre[..., start + 1:start + 2] = (pre[..., start + 1:start + 2] + row) / S
        pre[..., start + 2:start + 4] = pre[..., start + 2:start + 4] / S

    # Each cell's most likely class is shared by both of its boxes.
    classes = pre[..., :20].argmax(dim=-1, keepdim=True).to(pre.dtype)
    # Rows [class, confidence, x, y, w, h]: all box1 rows, then all box2 rows.
    per_box = [
        torch.cat([classes, pre[..., conf:conf + 1], pre[..., conf + 1:conf + 5]], dim=-1)
        .reshape(S * S, 6)
        for conf in (20, 25)
    ]
    return non_max_suppression(torch.cat(per_box, dim=0))

不知道要train多久,租了个服务器一直在跑

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值