Mask R-CNN

Mask R-CNN


  • -Model(pytorch版本

  • Introduction

1.视频教程:
B站、网易云课堂、腾讯课堂
2.代码地址:
Gitee
Github
3.存储地址:
Google云
百度云:
提取码:

download VOC 2012: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

一 论文导读

《Mask R-CNN》
—基于Mask R-CNN的目标检测
作者:Kaiming He, Georgia Gkioxari, Piotr Dollár, Ross Girshick
单位:Facebook AI Research(FAIR)
发表会议及时间:ICCV 2017

补充:Mask R-CNN在Faster-RCNN的基础上添加了分支网络,实现目标检测,目标像素分割的任务

Mask RCNN 可能算是RCNN系列的集大成者,是终结之作,其地位类似于语言模型中的BERT

RCNN系列:
RCNN–>Fast RCNN–>Faster RCNN–>Masker RCNN

  • Abstract

We present a conceptually simple, flexible, and general framework for object instance segmentation. Our approach efficiently detects objects in an image while simultaneously generating a high-quality segmentation mask for each instance. The method, called Mask R-CNN, extends Faster R-CNN by adding a branch for predicting an object mask in parallel with the existing branch for bounding box recognition. Mask R-CNN is simple to train and adds only a small overhead to Faster R-CNN, running at 5 fps. Moreover, Mask R-CNN is easy to generalize to other tasks, e.g., allowing us to estimate human poses in the same framework. We show top results in all three tracks of the COCO suite of challenges, including instance segmentation, bounding-box object detection, and person keypoint detection. Without bells and whistles, Mask R-CNN outperforms all existing, single-model entries on every task, including the COCO 2016 challenge winners. We hope our simple and effective approach will serve as a solid baseline and help ease future research in instance-level recognition. Code has been made available at: this https URL

1.将Roi Pooling层替换成了RoiAlign;
2.添加并列的FCN层(mask层);

Mask-RCNN特点:
1.在边框识别的基础上添加分支网络,用于语义Mask识别;
2.训练简单
3.可以方便的扩展到其他任务,比如人的姿态估计等;

RoIAlign的作用就是用双线性插值取代取整操作,从而使得每个RoI取得的特征能更好的对齐原图上的RoI区域
假设候选框坐标为左上角(0,9),右下角:(200,310),原图和featureMap的spaceRatio为10,那么映射到featureMap上的候选框为:左上角:(0,9/10),即为(0,0.9);右下角:(200/10,310/10),即为(20,31),那么候选框在特征图上的区域即为下图中红色区域。

Submission history
From: Kaiming He [view email]
[v1] Mon, 20 Mar 2017 17:53:38 UTC (6,270 KB)
[v2] Wed, 5 Apr 2017 20:14:55 UTC (7,041 KB)
[v3] Wed, 24 Jan 2018 07:54:08 UTC (7,061 KB)


paper

code

一 原理解析

在这里插入图片描述

在这里插入图片描述

class MaskRCNNPredictor(nn.Sequential):
    def __init__(self, in_channels, layers, dim_reduced, num_classes):
        """
        Arguments:
            in_channels (int)
            layers (Tuple[int])
            dim_reduced (int)
            num_classes (int)
        """

        d = OrderedDict()
        next_feature = in_channels
        for layer_idx, layer_features in enumerate(layers, 1):  # 首先是根据layers,定义了4个3x3的卷积,通道数都为256
            d['mask_fcn{}'.format(layer_idx)] = nn.Conv2d(next_feature, layer_features, 3, 1, 1)
            d['relu{}'.format(layer_idx)] = nn.ReLU(inplace=True)
            next_feature = layer_features
        # 用反卷积网络,stride=2,扩大了特征图到两倍大小
        d['mask_conv5'] = nn.ConvTranspose2d(next_feature, dim_reduced, 2, 2, 0)
        d['relu5'] = nn.ReLU(inplace=True)
        d['mask_fcn_logits'] = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)  # 用1x1卷积,获得分割结果,binary的,有cls那么多
        super().__init__(d)

        for name, param in self.named_parameters():  # 用kaiming_normal来初始化weight
            if 'weight' in name:
                nn.init.kaiming_normal_(param, mode='fan_out', nonlinearity='relu')

gt_mask = target['masks']  # 得到对应的mask(原始)
mask_loss = maskrcnn_loss(mask_logit, mask_proposal, pos_matched_idx, mask_label, gt_mask)
def maskrcnn_loss(mask_logit, proposal, matched_idx, label, gt_mask):
    matched_idx = matched_idx[:, None].to(proposal)  # 其中的.to用于统一数据类型
    roi = torch.cat((matched_idx, proposal), dim=1)

    M = mask_logit.shape[-1]
    gt_mask = gt_mask[:, None].to(roi)
    mask_target = roi_align(gt_mask, roi, 1., M, M, -1)[:, 0]  # 调用ROIAlign获得目标的mask

    idx = torch.arange(label.shape[0], device=label.device)
    mask_loss = F.binary_cross_entropy_with_logits(mask_logit[idx, label], mask_target)  # 计算mask部分的loss
    return mask_loss

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

二 代码实现

MyDataset

import xml.etree.ElementTree as ET
from PIL import Image
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
from torchvision import transforms

VOC_CLASSES = (
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow", "diningtable",
    "dog", "horse", "motorbike", "person", "pottedplant",
    "sheep", "sofa", "train", "tvmonitor"
)



class BaseDataset:
    """
    Main class for Generalized Dataset.
    """

    def __init__(self, max_workers=2, verbose=False):
        self.max_workers = max_workers
        self.verbose = verbose

    def __getitem__(self, i):  # 输入一个index,通过子类声明的`get_image`函数和`get_target`函数来获得对应的image和target
        img_id = self.ids[i]
        image = self.get_image(img_id)
        image = transforms.ToTensor()(image)
        target = self.get_target(img_id) if self.train else {}
        return image, target

    def __len__(self):
        return len(self.ids)

    def check_dataset(self, checked_id_file):
        """
        use multithreads to accelerate the process.
        check the dataset to avoid some problems listed in method `_check`.
        """  # 用多线程的方式来实现函数`_check`的功能

        if os.path.exists(checked_id_file):
            info = [line.strip().split(", ") for line in open(checked_id_file)]
            self.ids, self.aspect_ratios = zip(*info)
            return

        since = time.time()
        print("Checking the dataset...")

        executor = ThreadPoolExecutor(max_workers=self.max_workers)
        seqs = torch.arange(len(self)).chunk(self.max_workers)
        tasks = [executor.submit(self._check, seq.tolist()) for seq in seqs]

        outs = []
        for future in as_completed(tasks):
            outs.extend(future.result())
        if not hasattr(self, "id_compare_fn"):
            self.id_compare_fn = lambda x: int(x)
        outs.sort(key=lambda x: self.id_compare_fn(x[0]))

        with open(checked_id_file, "w") as f:
            for img_id, aspect_ratio in outs:
                f.write("{}, {:.4f}\n".format(img_id, aspect_ratio))

        info = [line.strip().split(", ") for line in open(checked_id_file)]
        self.ids, self.aspect_ratios = zip(*info)
        print("checked id file: {}".format(checked_id_file))
        print("{} samples are OK; {:.1f} seconds".format(len(self), time.time() - since))

    def _check(self, seq):  # 判断一个图片的信息中,是否boxes,labels,masks都齐了
        out = []
        for i in seq:
            img_id = self.ids[i]
            target = self.get_target(img_id)
            boxes = target["boxes"]
            labels = target["labels"]
            masks = target["masks"]

            try:
                assert len(boxes) > 0, "{}: len(boxes) = 0".format(i)
                assert len(boxes) == len(labels), "{}: len(boxes) != len(labels)".format(i)
                assert len(boxes) == len(masks), "{}: len(boxes) != len(masks)".format(i)

                out.append((img_id, self._aspect_ratios[i]))
            except AssertionError as e:
                if self.verbose:
                    print(img_id, e)
        return out




class VOCDataset(BaseDataset):
    # download VOC 2012: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    def __init__(self, data_dir, split, train=False):
        super().__init__()
        self.data_dir = data_dir
        self.split = split
        self.train = train

        # instances segmentation task
        id_file = os.path.join(data_dir, "ImageSets/Segmentation/{}.txt".format(split))  # 得到训练集的数据的id
        self.ids = [id_.strip() for id_ in open(id_file)]  # 并存在`self.ids`中
        self.id_compare_fn = lambda x: int(x.replace("_", ""))  # 替换掉数字中的`_```

        self.ann_file = os.path.join(data_dir, "Annotations/instances_{}.json".format(split))  # 得到ann文件的地址
        self._coco = None

        self.classes = VOC_CLASSES  # VOC中所存在的几种类
        # resutls' labels convert to annotation labels
        self.ann_labels = {self.classes.index(n): i for i, n in enumerate(self.classes)}

        checked_id_file = os.path.join(os.path.dirname(id_file), "checked_{}.txt".format(split))
        if train:
            if not os.path.exists(checked_id_file):
                self.make_aspect_ratios()  # 计算每一个图片的宽高比
            self.check_dataset(checked_id_file)  # 判断一个图片的信息中,是否boxes,labels,masks都齐了

    def make_aspect_ratios(self):
        self._aspect_ratios = []
        for img_id in self.ids:
            anno = ET.parse(os.path.join(self.data_dir, "Annotations", "{}.xml".format(img_id)))
            size = anno.findall("size")[0]
            width = size.find("width").text
            height = size.find("height").text
            ar = int(width) / int(height)
            self._aspect_ratios.append(ar)

    def get_image(self, img_id):  # 根据ID读入图片
        image = Image.open(os.path.join(self.data_dir, "JPEGImages/{}.jpg".format(img_id)))
        # image.show()
        return image.convert("RGB")

    def get_target(self, img_id):  # 根据ID获得训练目标,包括boxes,标签和masks
        masks = Image.open(os.path.join(self.data_dir, 'SegmentationObject/{}.png'.format(img_id)))  # mask的数值来自于图片
        # 压缩到0-1之间
        masks = transforms.ToTensor()(masks)
        # 找到mask中有哪几种不同的数值
        uni = masks.unique()
        # 其中在01之间的数值分别代表一个mask(实例分割)
        uni = uni[(uni > 0) & (uni < 1)]
        # 将图片中等于uni的部分分别转为值为0,1的图片
        masks = (masks == uni.reshape(-1, 1, 1)).to(torch.uint8)
        anno = ET.parse(os.path.join(self.data_dir, "Annotations", "{}.xml".format(img_id)))  # 对应的xml文件中读入
        boxes = []
        labels = []
        for obj in anno.findall("object"):  # 读入每一个图片的boxes和classes信息
            bndbox = obj.find("bndbox")
            bbox = [int(bndbox.find(tag).text) for tag in ["xmin", "ymin", "xmax", "ymax"]]
            name = obj.find("name").text
            label = self.classes.index(name)
            boxes.append(bbox)
            labels.append(label)

        boxes = torch.tensor(boxes, dtype=torch.float32)  # 转成tensor的数据类型
        labels = torch.tensor(labels)  # 转成tensor的数据类型

        img_id = torch.tensor([self.ids.index(img_id)])
        target = dict(image_id=img_id, boxes=boxes, labels=labels, masks=masks)  # 将所有的信息封装成dict,然后返回
        # print("masks shape",masks.size())
        return target

if __name__ == '__main__':
    voc_dataset = VOCDataset(r'.\VOCdevkit\VOC2012', 'train', train=True)
    print("len(voc_dataset)", len(voc_dataset))
    print(voc_dataset[0])


在这里插入图片描述

MaskRCNN 网络

from torch import nn
from torchvision import models
from torchvision.ops import misc
from collections import OrderedDict
import math

import torch
import torch.nn.functional as F

'''
数据处理
'''
def expand_detection(mask, box, padding):
    M = mask.shape[-1]
    scale = (M + 2 * padding) / M
    padded_mask = torch.nn.functional.pad(mask, (padding,) * 4)

    w_half = (box[:, 2] - box[:, 0]) * 0.5
    h_half = (box[:, 3] - box[:, 1]) * 0.5
    x_c = (box[:, 2] + box[:, 0]) * 0.5
    y_c = (box[:, 3] + box[:, 1]) * 0.5

    w_half = w_half * scale
    h_half = h_half * scale

    box_exp = torch.zeros_like(box)
    box_exp[:, 0] = x_c - w_half
    box_exp[:, 2] = x_c + w_half
    box_exp[:, 1] = y_c - h_half
    box_exp[:, 3] = y_c + h_half
    return padded_mask, box_exp.to(torch.int64)


def paste_masks_in_image(mask, box, padding, image_shape):
    mask, box = expand_detection(mask, box, padding)

    N = mask.shape[0]
    size = (N,) + tuple(image_shape)
    im_mask = torch.zeros(size, dtype=mask.dtype, device=mask.device)
    for m, b, im in zip(mask, box, im_mask):
        b = b.tolist()
        w = max(b[2] - b[0], 1)
        h = max(b[3] - b[1], 1)

        m = F.interpolate(m[None, None], size=(h, w), mode='bilinear', align_corners=False)[0][0]

        x1 = max(b[0], 0)
        y1 = max(b[1], 0)
        x2 = min(b[2], image_shape[1])
        y2 = min(b[3], image_shape[0])

        im[y1:y2, x1:x2] = m[(y1 - b[1]):(y2 - b[1]), (x1 - b[0]):(x2 - b[0])]
    return im_mask


class Transformer:
    def __init__(self, min_size, max_size, image_mean, image_std):
        self.min_size = min_size
        self.max_size = max_size
        self.image_mean = image_mean
        self.image_std = image_std

    def __call__(self, image, target):
        image = self.normalize(image)
        image, target = self.resize(image, target)
        image = self.batched_image(image)

        return image, target

    def normalize(self, image):
        if image.shape[0] == 1:
            image = image.repeat(3, 1, 1)

        dtype, device = image.dtype, image.device
        mean = torch.tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

    def resize(self, image, target):
        ori_image_shape = image.shape[-2:]
        min_size = float(min(image.shape[-2:]))
        max_size = float(max(image.shape[-2:]))

        scale_factor = min(self.min_size / min_size, self.max_size / max_size)
        size = [round(s * scale_factor) for s in ori_image_shape]
        image = F.interpolate(image[None], size=size, mode='bilinear', align_corners=False)[0]

        if target is None:
            return image, target

        box = target['boxes']
        box[:, [0, 2]] = box[:, [0, 2]] * image.shape[-1] / ori_image_shape[1]
        box[:, [1, 3]] = box[:, [1, 3]] * image.shape[-2] / ori_image_shape[0]
        target['boxes'] = box

        if 'masks' in target:
            mask = target['masks']
            mask = F.interpolate(mask[None].float(), size=size)[0].byte()
            target['masks'] = mask

        return image, target

    def batched_image(self, image, stride=32):
        size = image.shape[-2:]
        max_size = tuple(math.ceil(s / stride) * stride for s in size)

        batch_shape = (image.shape[-3],) + max_size
        batched_img = image.new_full(batch_shape, 0)
        batched_img[:, :image.shape[-2], :image.shape[-1]] = image

        return batched_img[None]

    def postprocess(self, result, image_shape, ori_image_shape):
        box = result['boxes']
        box[:, [0, 2]] = box[:, [0, 2]] * ori_image_shape[1] / image_shape[1]
        box[:, [1, 3]] = box[:, [1, 3]] * ori_image_shape[0] / image_shape[0]
        result['boxes'] = box

        if 'masks' in result:
            mask = result['masks']
            mask = paste_masks_in_image(mask, box, 1, ori_image_shape)
            result['masks'] = mask

        return result


'''
确定backbone
'''
class ResBackbone(nn.Module):
    def __init__(self, backbone_name, pretrained):
        super().__init__()
        body = models.resnet.__dict__[backbone_name](
            pretrained=pretrained, norm_layer=misc.FrozenBatchNorm2d)

        for name, parameter in body.named_parameters():
            if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)

        self.body = nn.ModuleDict(d for i, d in enumerate(body.named_children()) if i < 8)
        in_channels = 2048
        self.out_channels = 256

        self.inner_block_module = nn.Conv2d(in_channels, self.out_channels, 1)
        self.layer_block_module = nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1)

        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        for module in self.body.values():
            x = module(x)
        x = self.inner_block_module(x)
        x = self.layer_block_module(x)
        return x


'''
构建MaskRCNN网络准备
'''


class AnchorGenerator:
    def __init__(self, sizes, ratios):
        self.sizes = sizes
        self.ratios = ratios

        self.cell_anchor = None
        self._cache = {}

    def set_cell_anchor(self, dtype, device):
        if self.cell_anchor is not None:
            return
        sizes = torch.tensor(self.sizes, dtype=dtype, device=device)
        ratios = torch.tensor(self.ratios, dtype=dtype, device=device)

        h_ratios = torch.sqrt(ratios)  # 算出高上的变化比例
        w_ratios = 1 / h_ratios  # 算出宽上的变化比例

        hs = (sizes[:, None] * h_ratios[None, :]).view(-1)  # 分别算出变化后的9种高
        ws = (sizes[:, None] * w_ratios[None, :]).view(-1)  # 分别算出变化后的9种宽

        self.cell_anchor = torch.stack([-ws, -hs, ws, hs], dim=1) / 2  # 得到相对应中心点的anchors形状

    def grid_anchor(self, grid_size, stride):
        dtype, device = self.cell_anchor.dtype, self.cell_anchor.device
        shift_x = torch.arange(0, grid_size[1], dtype=dtype, device=device) * stride[1]  # 根据两个参数,计算相对左上角的位移
        shift_y = torch.arange(0, grid_size[0], dtype=dtype, device=device) * stride[0]

        y, x = torch.meshgrid(shift_y, shift_x)  # 变换成x,y坐标的形式
        x = x.reshape(-1)
        y = y.reshape(-1)
        shift = torch.stack((x, y, x, y), dim=1).reshape(-1, 1, 4)  # 变换成4个位置上的shift

        anchor = (shift + self.cell_anchor).reshape(-1, 4)  # 加上之前计算的9中anchor的坐标
        return anchor

    def cached_grid_anchor(self, grid_size, stride):
        key = grid_size + stride
        if key in self._cache:  # 如果这种类型的key组合之前计算过,将直接返回缓存不再计算
            return self._cache[key]
        anchor = self.grid_anchor(grid_size, stride)

        if len(self._cache) >= 3:  # 如果缓存的大于等于3个,便清理缓存
            self._cache.clear()
        self._cache[key] = anchor
        return anchor

    def __call__(self, feature, image_size):
        dtype, device = feature.dtype, feature.device  # 获得feature的类型和在不在gpu上
        grid_size = tuple(feature.shape[-2:])  # 获得特征的尺寸
        stride = tuple(int(i / g) for i, g in zip(image_size, grid_size))  # 并计算长宽上所采用的stride为多少

        self.set_cell_anchor(dtype, device)  # 计算每个anchors的尺寸

        anchor = self.cached_grid_anchor(grid_size, stride)  # 根据这两个参数计算实际的anchors相对左上角的位置
        return anchor


class RPNHead(nn.Module):  # RPN部分的网络结构
    def __init__(self, in_channels, num_anchors):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, 3, 1, 1)  # cls中先做一个3x3的卷积
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, 1)  # 接着通过一个1x1的卷积,获得每一个anchors对应的cls值
        self.bbox_pred = nn.Conv2d(in_channels, 4 * num_anchors,
                                   1)  # bbox_pred是通过一个1x1的卷积,获得4*num_anchors维度的对于原有anchor的调整

        for l in self.children():  # 对RPN中的参数进行初始化
            nn.init.normal_(l.weight, std=0.01)
            nn.init.constant_(l.bias, 0)

    def forward(self, x):
        x = F.relu(self.conv(x))  # 前两个有关cls,代表RPN图中的上一行
        logits = self.cls_logits(x)
        bbox_reg = self.bbox_pred(x)  # 这一行有关bbox_pred,代表RPN图中的下一行
        return logits, bbox_reg


class Matcher:
    def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, iou):
        """
        Arguments:
            iou (Tensor[M, N]): containing the pairwise quality between
            M ground-truth boxes and N predicted boxes.

        Returns:
            label (Tensor[N]): positive (1) or negative (0) label for each predicted box,
            -1 means ignoring this box.
            matched_idx (Tensor[N]): indices of gt box matched by each predicted box.
        """

        value, matched_idx = iou.max(dim=0)
        label = torch.full((iou.shape[1],), -1, dtype=torch.float, device=iou.device)

        label[value >= self.high_threshold] = 1  # 将大于high_threshold的位置定为1
        label[value < self.low_threshold] = 0  # 将小于low_threshold的位置定为0;其余剩下的保持-1

        if self.allow_low_quality_matches:  # 该部分将最大的IOU值保留作为1,以便当没有IOU大于high_threshold时,仍然有正样本
            highest_quality = iou.max(dim=1)[0]
            gt_pred_pairs = torch.where(iou == highest_quality[:, None])[1]
            label[gt_pred_pairs] = 1

        return label, matched_idx


class BalancedPositiveNegativeSampler:
    def __init__(self, num_samples, positive_fraction):
        self.num_samples = num_samples
        self.positive_fraction = positive_fraction

    def __call__(self, label):
        positive = torch.where(label == 1)[0]
        negative = torch.where(label == 0)[0]

        num_pos = int(self.num_samples * self.positive_fraction)  # 通过采样点数目和positive_fraction计算正样本应有多少
        num_pos = min(positive.numel(), num_pos)  # 比较positive样本的数量和应有正样本数量大小,选取小的那个
        num_neg = self.num_samples - num_pos  # 根据实际采样的正样本数量,计算需要多少负样本
        num_neg = min(negative.numel(), num_neg)  # 比较negative样本的数量和应有负样本数量大小,选取小的那个

        pos_perm = torch.randperm(positive.numel(), device=positive.device)[:num_pos]  # 根据上面计算的数量,随机应提取样本的index
        neg_perm = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

        pos_idx = positive[pos_perm]
        neg_idx = negative[neg_perm]

        return pos_idx, neg_idx


class RegionProposalNetwork(nn.Module):
    def __init__(self, anchor_generator, head,
                 fg_iou_thresh, bg_iou_thresh,
                 num_samples, positive_fraction,
                 reg_weights,
                 pre_nms_top_n, post_nms_top_n, nms_thresh):
        super().__init__()

        self.anchor_generator = anchor_generator  # 用于产生作为基础的anchor的位置
        self.head = head

        self.proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh,
                                        allow_low_quality_matches=True)  # 返回一张表说明该anchors是正样本还是负样本
        self.fg_bg_sampler = BalancedPositiveNegativeSampler(num_samples,
                                                             positive_fraction)  # 根据label(上一个函数的返回结果),和所需样本数及正样本比重,返回正样本和负样本的index
        self.box_coder = BoxCoder(reg_weights)  # 该函数用于将RPN的bbox预测值和anchor_generator编解码成实际的bbox的位置

        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.min_size = 1

    def create_proposal(self, anchor, objectness, pred_bbox_delta, image_shape):
        if self.training:
            pre_nms_top_n = self._pre_nms_top_n['training']
            post_nms_top_n = self._post_nms_top_n['training']
        else:
            pre_nms_top_n = self._pre_nms_top_n['testing']
            post_nms_top_n = self._post_nms_top_n['testing']

        pre_nms_top_n = min(objectness.shape[0], pre_nms_top_n)  # 输出的nms数量,等于目标的数目和设定的nms数量的最小值
        top_n_idx = objectness.topk(pre_nms_top_n)[1]  # 找到objectness中最大的pre_nms_top_n个的index
        score = objectness[top_n_idx]  # 找出最高这几个的cls值
        proposal = self.box_coder.decode(pred_bbox_delta[top_n_idx], anchor[top_n_idx])  # 解码为相对于pred_bbox_delta的bbox的位置

        proposal, score = process_box(proposal, score, image_shape,
                                      self.min_size)  # 该函数的作用在于使得bbox不超过图片的范围,且删除一些宽高小于min_size的bbox
        keep = nms(proposal, score, self.nms_thresh)[:post_nms_top_n]  # 实现了非极大值抑制,最多取post_nms_top_n个
        proposal = proposal[keep]
        return proposal

    def compute_loss(self, objectness, pred_bbox_delta, gt_box, anchor):
        iou = box_iou(gt_box, anchor)  # 计算gt与预先设定的anchor之间的iou
        label, matched_idx = self.proposal_matcher(iou)  # 返回是属于正样本还是负样本

        pos_idx, neg_idx = self.fg_bg_sampler(label)  # 根据采样总数和正样本比例,找出正负样本的index
        idx = torch.cat((pos_idx, neg_idx))  # 获得总的用来训练的index
        regression_target = self.box_coder.encode(gt_box[matched_idx[pos_idx]], anchor[pos_idx])  # 将gt的位置编码为与模型输出的位置一致

        objectness_loss = F.binary_cross_entropy_with_logits(objectness[idx], label[idx])  # 计算cls的loss
        box_loss = F.l1_loss(pred_bbox_delta[pos_idx], regression_target,
                             reduction='sum') / idx.numel()  # 用l1范数计算bbox部分的loss

        return objectness_loss, box_loss

    def forward(self, feature, image_shape, target=None):
        if target is not None:
            gt_box = target['boxes']
        anchor = self.anchor_generator(feature, image_shape)  # 计算得到anchor的相对位置

        objectness, pred_bbox_delta = self.head(feature)  # 输入RPN的网络中,获得置信度和预测框
        objectness = objectness.permute(0, 2, 3, 1).flatten()  # 将cls的那一通道移到最后一位
        pred_bbox_delta = pred_bbox_delta.permute(0, 2, 3, 1).reshape(-1, 4)  # 将bbox的那一通道移到最后一位

        proposal = self.create_proposal(anchor, objectness.detach(), pred_bbox_delta.detach(), image_shape)
        if self.training:  # 如果是train阶段,计算此时的loss
            objectness_loss, box_loss = self.compute_loss(objectness, pred_bbox_delta, gt_box, anchor)
            return proposal, dict(rpn_objectness_loss=objectness_loss, rpn_box_loss=box_loss)

        return proposal, {}


def roi_align(features, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
    if torch.__version__ >= "1.5.0":
        return torch.ops.torchvision.roi_align(
            features, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, False)
    else:
        return torch.ops.torchvision.roi_align(
            features, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio)


class RoIAlign:
    """
    Performs Region of Interest (RoI) Align operator described in Mask R-CNN

    """

    def __init__(self, output_size, sampling_ratio):
        """
        Arguments:
            output_size (Tuple[int, int]): the size of the output after the cropping
                is performed, as (height, width)
            sampling_ratio (int): number of sampling points in the interpolation grid
                used to compute the output value of each pooled output bin. If > 0,
                then exactly sampling_ratio x sampling_ratio grid points are used. If
                <= 0, then an adaptive number of grid points are used (computed as
                ceil(roi_width / pooled_w), and likewise for height). Default: -1
        """

        self.output_size = output_size
        self.sampling_ratio = sampling_ratio  # 按照sampling_ratio x sampling_ratio的grid来做ROI Align
        self.spatial_scale = None

    def setup_scale(self, feature_shape, image_shape):
        if self.spatial_scale is not None:
            return

        possible_scales = []
        for s1, s2 in zip(feature_shape, image_shape):
            scale = 2 ** int(math.log2(s1 / s2))
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]
        self.spatial_scale = possible_scales[0]

    def __call__(self, feature, proposal, image_shape):
        """
        Arguments:
            feature (Tensor[N, C, H, W])
            proposal (Tensor[K, 4])
            image_shape (Torch.Size([H, W]))

        Returns:
            output (Tensor[K, C, self.output_size[0], self.output_size[1]])

        """
        idx = proposal.new_full((proposal.shape[0], 1), 0)
        roi = torch.cat((idx, proposal), dim=1)

        self.setup_scale(feature.shape[-2:], image_shape)
        return roi_align(feature.to(roi), roi, self.spatial_scale, self.output_size[0], self.output_size[1],
                         self.sampling_ratio)


class MaskRCNNPredictor(nn.Sequential):
    def __init__(self, in_channels, layers, dim_reduced, num_classes):
        """
        Arguments:
            in_channels (int)
            layers (Tuple[int])
            dim_reduced (int)
            num_classes (int)
        """

        d = OrderedDict()
        next_feature = in_channels
        for layer_idx, layer_features in enumerate(layers, 1):  # 首先是根据layers,定义了4个3x3的卷积,通道数都为256
            d['mask_fcn{}'.format(layer_idx)] = nn.Conv2d(next_feature, layer_features, 3, 1, 1)
            d['relu{}'.format(layer_idx)] = nn.ReLU(inplace=True)
            next_feature = layer_features
        # 用反卷积网络,stride=2,扩大了特征图到两倍大小
        d['mask_conv5'] = nn.ConvTranspose2d(next_feature, dim_reduced, 2, 2, 0)
        d['relu5'] = nn.ReLU(inplace=True)
        d['mask_fcn_logits'] = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)  # 用1x1卷积,获得分割结果,binary的,有cls那么多
        super().__init__(d)

        for name, param in self.named_parameters():  # 用kaiming_normal来初始化weight
            if 'weight' in name:
                nn.init.kaiming_normal_(param, mode='fan_out', nonlinearity='relu')


class FastRCNNPredictor(nn.Module):
    def __init__(self, in_channels, mid_channels, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, mid_channels)  # 两个全链接网络
        self.fc2 = nn.Linear(mid_channels, mid_channels)
        self.cls_score = nn.Linear(mid_channels, num_classes)  # 之后一个输出类别
        self.bbox_pred = nn.Linear(mid_channels, num_classes * 4)  # 一个输出bbox,每一类都有4个,总共num_classes * 4个

    def forward(self, x):
        x = x.flatten(start_dim=1)  # 将7x7展平
        x = F.relu(self.fc1(x))  # 共享俩个全链接
        x = F.relu(self.fc2(x))
        score = self.cls_score(x)
        bbox_delta = self.bbox_pred(x)

        return score, bbox_delta


class BoxCoder:
    def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def encode(self, reference_box, proposal):
        """
        Encode a set of proposals with respect to some
        reference boxes

        Arguments:
            reference_boxes (Tensor[N, 4]): reference boxes
            proposals (Tensor[N, 4]): boxes to be encoded
        """

        width = proposal[:, 2] - proposal[:, 0]
        height = proposal[:, 3] - proposal[:, 1]
        ctr_x = proposal[:, 0] + 0.5 * width
        ctr_y = proposal[:, 1] + 0.5 * height

        gt_width = reference_box[:, 2] - reference_box[:, 0]
        gt_height = reference_box[:, 3] - reference_box[:, 1]
        gt_ctr_x = reference_box[:, 0] + 0.5 * gt_width
        gt_ctr_y = reference_box[:, 1] + 0.5 * gt_height

        dx = self.weights[0] * (gt_ctr_x - ctr_x) / width
        dy = self.weights[1] * (gt_ctr_y - ctr_y) / height
        dw = self.weights[2] * torch.log(gt_width / width)
        dh = self.weights[3] * torch.log(gt_height / height)

        delta = torch.stack((dx, dy, dw, dh), dim=1)
        return delta

    def decode(self, delta, box):
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Arguments:
            delta (Tensor[N, 4]): encoded boxes.
            boxes (Tensor[N, 4]): reference boxes.
        """

        dx = delta[:, 0] / self.weights[0]
        dy = delta[:, 1] / self.weights[1]
        dw = delta[:, 2] / self.weights[2]
        dh = delta[:, 3] / self.weights[3]

        dw = torch.clamp(dw, max=self.bbox_xform_clip)
        dh = torch.clamp(dh, max=self.bbox_xform_clip)

        width = box[:, 2] - box[:, 0]
        height = box[:, 3] - box[:, 1]
        ctr_x = box[:, 0] + 0.5 * width
        ctr_y = box[:, 1] + 0.5 * height

        pred_ctr_x = dx * width + ctr_x
        pred_ctr_y = dy * height + ctr_y
        pred_w = torch.exp(dw) * width
        pred_h = torch.exp(dh) * height

        xmin = pred_ctr_x - 0.5 * pred_w
        ymin = pred_ctr_y - 0.5 * pred_h
        xmax = pred_ctr_x + 0.5 * pred_w
        ymax = pred_ctr_y + 0.5 * pred_h

        target = torch.stack((xmin, ymin, xmax, ymax), dim=1)
        return target


def box_iou(box_a, box_b):
    """
    Arguments:
        boxe_a (Tensor[N, 4])
        boxe_b (Tensor[M, 4])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in box_a and box_b
    """

    lt = torch.max(box_a[:, None, :2], box_b[:, :2])
    rb = torch.min(box_a[:, None, 2:], box_b[:, 2:])

    wh = (rb - lt).clamp(min=0)
    inter = wh[:, :, 0] * wh[:, :, 1]
    area_a = torch.prod(box_a[:, 2:] - box_a[:, :2], 1)
    area_b = torch.prod(box_b[:, 2:] - box_b[:, :2], 1)

    return inter / (area_a[:, None] + area_b - inter)


def process_box(box, score, image_shape, min_size):
    """
    Clip boxes in the image size and remove boxes which are too small.
    """

    box[:, [0, 2]] = box[:, [0, 2]].clamp(0, image_shape[1])
    box[:, [1, 3]] = box[:, [1, 3]].clamp(0, image_shape[0])

    w, h = box[:, 2] - box[:, 0], box[:, 3] - box[:, 1]
    keep = torch.where((w >= min_size) & (h >= min_size))[0]
    box, score = box[keep], score[keep]
    return box, score


def nms(box, score, threshold):
    """
    Arguments:
        box (Tensor[N, 4])
        score (Tensor[N]): scores of the boxes.
        threshold (float): iou threshold.

    Returns:
        keep (Tensor): indices of boxes filtered by NMS.
    """

    return torch.ops.torchvision.nms(box, score, threshold)


# just for test. It is too slow. Don't use it during train
def slow_nms(box, nms_thresh):
    idx = torch.arange(box.size(0))

    keep = []
    while idx.size(0) > 0:
        keep.append(idx[0].item())
        head_box = box[idx[0], None, :]
        remain = torch.where(box_iou(head_box, box[idx]) <= nms_thresh)[1]
        idx = idx[remain]

    return keep


class RoIHeads(nn.Module):
    def __init__(self, box_roi_pool, box_predictor,
                 fg_iou_thresh, bg_iou_thresh,
                 num_samples, positive_fraction,
                 reg_weights,
                 score_thresh, nms_thresh, num_detections):
        super().__init__()
        self.box_roi_pool = box_roi_pool  # 对应之前的`box_roi_pool = RoIAlign(output_size=(7, 7), sampling_ratio=2)`
        self.box_predictor = box_predictor  # 对应`FastRCNNPredictor(in_channels, mid_channels, num_classes)`

        self.mask_roi_pool = None
        self.mask_predictor = None

        self.proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh,
                                        allow_low_quality_matches=False)  # 返回一张表说明该anchors是正样本还是负样本
        self.fg_bg_sampler = BalancedPositiveNegativeSampler(num_samples,
                                                             positive_fraction)  # 根据label(上一个函数的返回结果),和所需样本数及正样本比重,返回正样本和负样本的index
        self.box_coder = BoxCoder(reg_weights)  # 该函数用于将RPN的bbox预测值和anchor_generator编解码成实际的bbox的位置

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.num_detections = num_detections
        self.min_size = 1

    def has_mask(self):
        if self.mask_roi_pool is None:
            return False
        if self.mask_predictor is None:
            return False
        return True

    def select_training_samples(self, proposal, target):
        gt_box = target['boxes']
        gt_label = target['labels']
        proposal = torch.cat((proposal, gt_box))

        iou = box_iou(gt_box, proposal)  # # 计算gt与proposal之间的iou
        pos_neg_label, matched_idx = self.proposal_matcher(iou)  # 返回是属于正样本还是负样本
        pos_idx, neg_idx = self.fg_bg_sampler(pos_neg_label)  # 根据采样总数和正样本比例,找出正负样本的index
        idx = torch.cat((pos_idx, neg_idx))  # 获得总的用来训练的index

        regression_target = self.box_coder.encode(gt_box[matched_idx[pos_idx]],
                                                  proposal[pos_idx])  # 将gt的位置编码为与模型输出的位置一致
        proposal = proposal[idx]  # 得到所有用于训练的proposal
        matched_idx = matched_idx[idx]  # 得到所有用于训练的matched_idx
        label = gt_label[matched_idx]  # 得到所有用于训练的label
        num_pos = pos_idx.shape[0]
        label[num_pos:] = 0

        return proposal, matched_idx, label, regression_target

    def fastrcnn_inference(self, class_logit, box_regression, proposal, image_shape):
        N, num_classes = class_logit.shape

        device = class_logit.device
        pred_score = F.softmax(class_logit, dim=-1)  # 类别做softmax
        box_regression = box_regression.reshape(N, -1, 4)

        boxes = []
        labels = []
        scores = []
        for l in range(1, num_classes):  # 一类一类来,排除第一类(背景)
            score, box_delta = pred_score[:, l], box_regression[:, l]  # 选出那一类对应的

            keep = score >= self.score_thresh  # 只有预测值高于score_thresh时才会认为有那一类
            box, score, box_delta = proposal[keep], score[keep], box_delta[keep]  # 抽出此时对应的box,score,box
            box = self.box_coder.decode(box_delta, box)  # 解码box_delta为绝对位置

            box, score = process_box(box, score, image_shape,
                                     self.min_size)  # 该函数的作用在于使得bbox不超过图片的范围,且删除一些宽高小于min_size的bbox

            keep = nms(box, score, self.nms_thresh)[:self.num_detections]  # 实现了非极大值抑制,最多取num_detections个
            box, score = box[keep], score[keep]
            label = torch.full((len(keep),), l, dtype=keep.dtype, device=device)

            boxes.append(box)
            labels.append(label)
            scores.append(score)

        results = dict(boxes=torch.cat(boxes), labels=torch.cat(labels), scores=torch.cat(scores))  # 存入所有结果并输出
        return results

    def forward(self, feature, proposal, image_shape, target):
        if self.training:
            proposal, matched_idx, label, regression_target = self.select_training_samples(proposal,
                                                                                           target)  # 根据规则筛选出用于训练的目标

        box_feature = self.box_roi_pool(feature, proposal, image_shape)  # 用ROIAlign得到大小为[512,256,7,7]的特征图
        class_logit, box_regression = self.box_predictor(box_feature)  # 得到分类结果和bbox预测结果

        result, losses = {}, {}
        if self.training:
            classifier_loss, box_reg_loss = fastrcnn_loss(class_logit, box_regression, label,
                                                          regression_target)  # 返回得到cls和bbox的loss
            losses = dict(roi_classifier_loss=classifier_loss, roi_box_loss=box_reg_loss)
        else:
            result = self.fastrcnn_inference(class_logit, box_regression, proposal, image_shape)  # 用于推断的部分

        if self.has_mask():
            if self.training:
                num_pos = regression_target.shape[0]

                mask_proposal = proposal[:num_pos]  # 先把对应正样本的位置准备好
                pos_matched_idx = matched_idx[:num_pos]  # 先把对应正样本的index也准备好
                mask_label = label[:num_pos]  # 先把对应正样本的label也准备好

                '''
                # -------------- critial ----------------
                box_regression = box_regression[:num_pos].reshape(num_pos, -1, 4)
                idx = torch.arange(num_pos, device=mask_label.device)
                mask_proposal = self.box_coder.decode(box_regression[idx, mask_label], mask_proposal)
                # ---------------------------------------
                '''

                if mask_proposal.shape[0] == 0:  # 如果没有mask_proposal,则直接返回0
                    losses.update(dict(roi_mask_loss=torch.tensor(0)))
                    return result, losses
            else:  # 这里是inference部分,`result['boxes']``self.fastrcnn_inference`的bbox输出
                mask_proposal = result['boxes']

                if mask_proposal.shape[0] == 0:  # 如果没有预测框,则直接返回0
                    result.update(dict(masks=torch.empty((0, 28, 28))))
                    return result, losses

            mask_feature = self.mask_roi_pool(feature, mask_proposal, image_shape)  # 对应第二个ROIAlign,得到[?,256,14,14]
            mask_logit = self.mask_predictor(mask_feature)  # 输入分割网络,得到[?,cls+1,28,28]

            if self.training:
                gt_mask = target['masks']  # 得到对应的mask(原始)
                mask_loss = maskrcnn_loss(mask_logit, mask_proposal, pos_matched_idx, mask_label, gt_mask)
                losses.update(dict(roi_mask_loss=mask_loss))
            else:
                label = result['labels']
                idx = torch.arange(label.shape[0], device=label.device)
                mask_logit = mask_logit[idx, label]  # 按照label的顺序排列mask

                mask_prob = mask_logit.sigmoid()  # 并做sigmoid
                result.update(dict(masks=mask_prob))

        return result, losses


'''
损失函数
'''


def maskrcnn_loss(mask_logit, proposal, matched_idx, label, gt_mask):
    matched_idx = matched_idx[:, None].to(proposal)  # 其中的.to用于统一数据类型
    roi = torch.cat((matched_idx, proposal), dim=1)

    M = mask_logit.shape[-1]
    gt_mask = gt_mask[:, None].to(roi)
    mask_target = roi_align(gt_mask, roi, 1., M, M, -1)[:, 0]  # 调用ROIAlign获得目标的mask

    idx = torch.arange(label.shape[0], device=label.device)
    mask_loss = F.binary_cross_entropy_with_logits(mask_logit[idx, label], mask_target)  # 计算mask部分的loss
    return mask_loss


def fastrcnn_loss(class_logit, box_regression, label, regression_target):
    classifier_loss = F.cross_entropy(class_logit, label)  # 类别的loss用交叉熵来计算

    N, num_pos = class_logit.shape[0], regression_target.shape[0]
    box_regression = box_regression.reshape(N, -1, 4)
    box_regression, label = box_regression[:num_pos], label[
                                                      :num_pos]  # 由于`select_training_samples`中pos的index接在前面,所以前几个就代表pos的
    box_idx = torch.arange(num_pos, device=label.device)

    box_reg_loss = F.smooth_l1_loss(box_regression[box_idx, label], regression_target,
                                    reduction='sum') / N  # 对正样本计算bbox预测的loss

    return classifier_loss, box_reg_loss


def maskrcnn_loss(mask_logit, proposal, matched_idx, label, gt_mask):
    matched_idx = matched_idx[:, None].to(proposal)  # 其中的.to用于统一数据类型
    roi = torch.cat((matched_idx, proposal), dim=1)

    M = mask_logit.shape[-1]
    gt_mask = gt_mask[:, None].to(roi)
    mask_target = roi_align(gt_mask, roi, 1., M, M, -1)[:, 0]  # 调用ROIAlign获得目标的mask

    idx = torch.arange(label.shape[0], device=label.device)
    mask_loss = F.binary_cross_entropy_with_logits(mask_logit[idx, label], mask_target)  # 计算mask部分的loss
    return mask_loss
'''

构建MaskRCNN网络

'''
class MaskRCNN(nn.Module):
    """
    Implements Mask R-CNN.
    """

    def __init__(self, backbone, num_classes,  # 输入用于计算特征的backbone网络;分类网络分类类别的数目(加上背景一类)
                 # RPN parameters
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 # 最小的anchor和GT box之间的IOU,大于它的被认为是正样本;最大的anchor和GT box之间的IOU,小于它的被认为是负样本
                 rpn_num_samples=256, rpn_positive_fraction=0.5,  # 在RPN训练中采样用于计算loss的anchor的数目;RPN训练中正样本所占的比例
                 rpn_reg_weights=(1., 1., 1., 1.),  # 用于编解码bounding boxes
                 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,  # 在train中选择bbox中最高的的2000做nms,做完后保留前1000个
                 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,  # test同理
                 rpn_nms_thresh=0.7,  # nms的阈值为0.7
                 # RoIHeads parameters
                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,  # 最小的实际的bbox与预测的IOU,高于这个被认为是正样本,反之为负样本
                 box_num_samples=512, box_positive_fraction=0.25,  # 训练时候用于计算loss所采样的数量,以及正样本的占比
                 box_reg_weights=(10., 10., 5., 5.),  # 用于编解码bounding boxes
                 box_score_thresh=0.1, box_nms_thresh=0.6,
                 box_num_detections=100):  # 只采用cls部分的score大于box_score_thresh的,nms的阈值为0.7,每一类最多检查100super().__init__()
        self.backbone = backbone
        out_channels = backbone.out_channels  # 256

        # ------------- RPN --------------------------
        anchor_sizes = (128, 256, 512)  # 所采用的anchors的基础大小有这三种
        anchor_ratios = (0.5, 1, 2)  # 所采用的anchors的长宽比有这三种
        num_anchors = len(anchor_sizes) * len(anchor_ratios)  # 总的anchors类型有9种
        rpn_anchor_generator = AnchorGenerator(anchor_sizes, anchor_ratios)
        rpn_head = RPNHead(out_channels, num_anchors)  # 声明RPN结构的网络部分

        # 声明RPN网络的proposal部分
        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
        self.rpn = RegionProposalNetwork(
            rpn_anchor_generator, rpn_head,
            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
            rpn_num_samples, rpn_positive_fraction,
            rpn_reg_weights,
            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)

        # ------------ RoIHeads --------------------------
        # 用于分类和bbox的部分,align到7x7
        box_roi_pool = RoIAlign(output_size=(7, 7), sampling_ratio=2)

        resolution = box_roi_pool.output_size[0]
        in_channels = out_channels * resolution ** 2
        mid_channels = 1024
        box_predictor = FastRCNNPredictor(in_channels, mid_channels, num_classes)  # 定义用于分类和bbox部分的网络

        self.head = RoIHeads(
            box_roi_pool, box_predictor,
            box_fg_iou_thresh, box_bg_iou_thresh,
            box_num_samples, box_positive_fraction,
            box_reg_weights,
            box_score_thresh, box_nms_thresh, box_num_detections)

        # 对应maskRCNN的mask预测部分,align到14x14 (因为作者是基于原有的fasterRCNN改的,所以通过以下的方式定义mask部分
        self.head.mask_roi_pool = RoIAlign(output_size=(14, 14), sampling_ratio=2)

        layers = (256, 256, 256, 256)
        dim_reduced = 256
        self.head.mask_predictor = MaskRCNNPredictor(out_channels, layers, dim_reduced, num_classes)

        # ------------ Transformer -------------------------- 将输入的图片缩放到固定的大小以及做归一化
        self.transformer = Transformer(
            min_size=800, max_size=1333,
            image_mean=[0.485, 0.456, 0.406],
            image_std=[0.229, 0.224, 0.225])

    def forward(self, image, target=None):
        ori_image_shape = image.shape[-2:]  # 记录图片的原始大小

        image, target = self.transformer(image, target)  # 将图片做变换
        image_shape = image.shape[-2:]
        feature = self.backbone(image)  # 通过backbone网络得到一个特征,作为RPN的输入和最后三层的输入

        proposal, rpn_losses = self.rpn(feature, image_shape, target)
        result, roi_losses = self.head(feature, proposal, image_shape, target)

        if self.training:
            return dict(**rpn_losses, **roi_losses)  # 训练下返回两个loss
        else:
            result = self.transformer.postprocess(result, image_shape, ori_image_shape)
            return result



if __name__ == '__main__':
    from MyDataset import VOCDataset
    # 指定设备
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 构建Dataset
    voc_dataset = VOCDataset(r'.\VOCdevkit\VOC2012', 'train', train=True)
    # 打印数据集个数
    print("len(voc_dataset)", len(voc_dataset))
    # for i, (image, target) in enumerate(data_loader):  # 每次读入一个数据
    # 获取第一个数据集的图像和标签
    image, target = voc_dataset[0]
    # 图像和标注数据转人设备
    image = image.to(device)  # 将image和target转换成对应的数据类型
    target = {k: v.to(device) for k, v in target.items()}
    # 构建backbone
    backbone = ResBackbone('resnet50',False)  # maskrcnn中用的backbone基于resnet50,并有预训练的模型
    # 构建网络MaskRCNN
    model = MaskRCNN(backbone, num_classes=20)
    # 模型转入设备
    model = model.to(device)
    # 前向传播得出损失
    losses = model(image, target)  # 将image和target输入模型,获得结果
    print(losses)

在这里插入图片描述

2.1 模型构建

2.2 模型训练

2.3 模型预测

2.4 模型验证

三 问题思索

四 改进方案

五 额外补充

获取数据流程

def __getitem__(self, i):  # 输入一个index,通过子类声明的`get_image`函数和`get_target`函数来获得对应的image和target
    img_id = self.ids[i]
    image = self.get_image(img_id)
    image = transforms.ToTensor()(image)
    target = self.get_target(img_id) if self.train else {}
    return image, target

step1:获取图像

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

make_aspect_ratios:计算每一个图片的宽高比

在这里插入图片描述
在这里插入图片描述

# 1.根路径
data_dir = r'.\VOCdevkit\VOC2012'

# 2.确定train.txt
id_file = os.path.join(data_dir, "ImageSets/Segmentation/{}.txt".format(split)) 
# 3 读取txt中所有图像名称 
self.ids = [id_.strip() for id_ in open(id_file)]

# 标注文件夹
self.ann_file = os.path.join(data_dir, "Annotations/instances_{}.json".format(split))  # 得到ann文件的地址

# 确定类别
self.classes = VOC_CLASSES  # VOC中所存在的几种类
# 标注标签 名称转类别数
self.ann_labels = {self.classes.index(n): i for i, n in enumerate(self.classes)}

    def get_image(self, img_id):  # 根据ID读入图片
        image = Image.open(os.path.join(self.data_dir, "JPEGImages/{}.jpg".format(img_id)))
        # image.show()
        return image.convert("RGB")

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

# 输入 masks shape :  1 334 500
# 将图片中等于uni的部分分别转为值为0,1的图片,同时将每个mask单独抽出来
masks = (masks == uni.reshape(-1, 1, 1)).to(torch.uint8)
# 输出 masks shape : 2 334 500
 masks = Image.open(os.path.join(self.data_dir, 'SegmentationObject/{}.png'.format("2007_000876")))  # mask的数值来自于图片
 # 压缩到0-1之间
 masks = transforms.ToTensor()(masks)
 # 找到mask中有哪几种不同的数值
 uni = masks.unique()
 # 其中在01之间的数值分别代表一个mask(实例分割)
 uni = uni[(uni > 0) & (uni < 1)]
 # 将图片中等于uni的部分分别转为值为0,1的图片
 res = uni.reshape(-1,1,1)
 masks = (masks == uni.reshape(-1, 1, 1)).to(torch.uint8)

 mask1 = masks.cuda().cpu().numpy()[0]
 mask2 = masks.cuda().cpu().numpy()[1]
 mask1[mask1 == 1] = 255
 m1 = Image.fromarray(mask1)
 m1.show()

 mask2[mask1 == 1] = 255
 m2 = Image.fromarray(mask2)
 m2.show()

在这里插入图片描述
在这里插入图片描述

# 读取xml标注
anno = ET.parse(os.path.join(self.data_dir, "Annotations", "{}.xml".format(img_id))) 
# 框 坐标
boxes = []
# 类别
labels = []
for obj in anno.findall("object"):  # 读入每一个图片的boxes和classes信息
    bndbox = obj.find("bndbox")
    bbox = [int(bndbox.find(tag).text) for tag in ["xmin", "ymin", "xmax", "ymax"]]
    name = obj.find("name").text
    label = self.classes.index(name)

    boxes.append(bbox)
    labels.append(label)
# 图像名称 目标框 类别id 分割掩膜
target = dict(image_id=img_id, boxes=boxes, labels=labels, masks=masks)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值