P2PNet(代码阅读笔记)

P2PNet 代码阅读笔记

在这里插入图片描述

一、主干网络

在这里插入图片描述
主干网络采用的是VGG16（推理脚本中 `--backbone` 的默认值为 `vgg16_bn`，即带批归一化的版本）。

class BackboneBase_VGG(nn.Module):
    """Wrap the ``features`` stack of a torchvision VGG16 / VGG16-BN model.

    With ``return_interm_layers=True`` the stack is cut into four stages
    (``body1`` .. ``body4``). Each stage ends right before a max-pool, so the
    stage outputs are down-sampled 2x, 4x, 8x and 16x relative to the input,
    and ``forward`` returns all four of them.

    With ``return_interm_layers=False`` a single ``body`` covering all four
    stages is built (16x down-sample) and ``forward`` returns a one-element
    list.

    Args:
        backbone: torchvision VGG model exposing a ``features`` Sequential.
        num_channels: stored as ``self.num_channels`` for callers.
        name: ``'vgg16_bn'`` or ``'vgg16'``; selects the layer indices.
        return_interm_layers: whether to expose the four intermediate stages.

    Raises:
        ValueError: if ``name`` is not a supported VGG variant.
    """

    # Indices into ``backbone.features`` delimiting the four stages; each
    # boundary is the index of a max-pool layer in the torchvision layout.
    _STAGE_BOUNDS = {
        'vgg16_bn': (0, 13, 23, 33, 43),
        'vgg16': (0, 9, 16, 23, 30),
    }

    def __init__(self, backbone: nn.Module, num_channels: int, name: str, return_interm_layers: bool):
        super().__init__()
        if name not in self._STAGE_BOUNDS:
            raise ValueError(
                f"unsupported VGG variant {name!r}; expected one of {sorted(self._STAGE_BOUNDS)}"
            )
        bounds = self._STAGE_BOUNDS[name]
        features = list(backbone.features.children())
        if return_interm_layers:
            self.body1 = nn.Sequential(*features[bounds[0]:bounds[1]])
            self.body2 = nn.Sequential(*features[bounds[1]:bounds[2]])
            self.body3 = nn.Sequential(*features[bounds[2]:bounds[3]])
            self.body4 = nn.Sequential(*features[bounds[3]:bounds[4]])
        else:
            # 16x down-sample: stop right before the fifth max-pool (index 43 for
            # vgg16_bn, 30 for vgg16); including it would down-sample 32x.
            self.body = nn.Sequential(*features[:bounds[4]])
        self.num_channels = num_channels
        self.return_interm_layers = return_interm_layers

    def forward(self, tensor_list):
        """Return a list of feature maps (four stages, or one if not intermediate)."""
        if not self.return_interm_layers:
            return [self.body(tensor_list)]

        out = []
        xs = tensor_list
        # stages are chained: each one consumes the previous stage's output
        for layer in (self.body1, self.body2, self.body3, self.body4):
            xs = layer(xs)
            out.append(xs)
        return out


class Backbone_VGG(BackboneBase_VGG):
    """VGG16 / VGG16-BN backbone initialised with torchvision pretrained weights.

    No layers are frozen here; BatchNorm layers (vgg16_bn) stay regular,
    trainable modules.

    Args:
        name: ``'vgg16_bn'`` or ``'vgg16'``.
        return_interm_layers: forwarded to ``BackboneBase_VGG``.

    Raises:
        ValueError: if ``name`` is not a supported VGG variant.
    """
    def __init__(self, name: str, return_interm_layers: bool):
        if name == 'vgg16_bn':
            backbone = models.vgg16_bn(pretrained=True)
        elif name == 'vgg16':
            backbone = models.vgg16(pretrained=True)
        else:
            # previously fell through to an UnboundLocalError on ``backbone``
            raise ValueError(f"unsupported VGG variant {name!r}; expected 'vgg16_bn' or 'vgg16'")
        # channel count recorded for callers (stored as self.num_channels by the base class)
        num_channels = 256
        super().__init__(backbone, num_channels, name, return_interm_layers)

VGG16和VGG16_bn的差别在于后者在每个卷积层之后都加入了BatchNorm（批归一化）层。
下图中D列即为VGG16的结构。
在这里插入图片描述

1.1 VGG16

Backbone_VGG16(
  (body): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
  )
)

1.2 VGG16_bn

Backbone_VGG16_bn(
  (body): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): ReLU(inplace=True)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): ReLU(inplace=True)
    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (22): ReLU(inplace=True)
    
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (26): ReLU(inplace=True)
    (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (32): ReLU(inplace=True)
    
    (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (36): ReLU(inplace=True)
    (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (39): ReLU(inplace=True)
    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (42): ReLU(inplace=True)
    (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
)

二、P2PNet网络

代码中实际的网络结构与论文中的并不完全相同：代码里额外加入了FPN（特征金字塔）层，如下图所示：
在这里插入图片描述

2.1 P2PNet

# the definition of the P2PNet model
class P2PNet(nn.Module):
    """P2PNet: predicts head points as offsets from a dense set of anchor points.

    For every anchor point the model outputs 2-way logits (index 0 =
    background, index 1 = person) and an (x, y) offset; the predicted point is
    ``anchor + offset``.

    Args:
        backbone: feature extractor returning a list with at least four
            feature maps; entries 1-3 are fed to the FPN decoder (presumably
            ``Backbone_VGG`` with ``return_interm_layers=True`` — confirm).
        row, line: the anchor points of each feature-map cell form a
            ``row x line`` grid.
    """

    def __init__(self, backbone, row=2, line=2):
        super().__init__()
        self.backbone = backbone
        # two classes: background and person
        self.num_classes = 2
        # number of anchor points per feature-map cell
        num_anchor_points = row * line

        # head that regresses an (x, y) offset per anchor point
        self.regression = RegressionModel(num_features_in=256, num_anchor_points=num_anchor_points)
        # head that classifies each anchor point
        self.classification = ClassificationModel(num_features_in=256,
                                                  num_classes=self.num_classes,
                                                  num_anchor_points=num_anchor_points)

        # anchors only on pyramid level 3, i.e. one cell every 2**3 = 8 input pixels
        self.anchor_points = AnchorPoints(pyramid_levels=[3,], row=row, line=line)

        self.fpn = Decoder(256, 512, 512)

    def forward(self, samples: NestedTensor):
        """Run the network.

        Returns:
            dict with ``'pred_logits'`` of shape ``(B, N, 2)`` and
            ``'pred_points'`` of shape ``(B, N, 2)``, ``N`` being the number
            of anchor points of the image.
        """
        # get the backbone features
        features = self.backbone(samples)
        # forward the feature pyramid
        features_fpn = self.fpn([features[1], features[2], features[3]])

        batch_size = features[0].shape[0]
        # both heads use FPN output 1 — presumably the stride-8 map that
        # matches the level-3 anchors (Decoder is not shown here)
        regression = self.regression(features_fpn[1]) * 100
        classification = self.classification(features_fpn[1])
        # AnchorPoints does not necessarily follow the input's device, so put
        # the anchors on the same device as the head outputs before adding.
        anchor_points = self.anchor_points(samples).repeat(batch_size, 1, 1).to(regression.device)
        # decode the points as prediction
        output_coord = regression + anchor_points
        output_class = classification
        out = {'pred_logits': output_class, 'pred_points': output_coord}

        return out

2.2 回归层

# the network framework of the regression branch
class RegressionModel(nn.Module):
    """Point-offset head: an (x, y) offset for every anchor at every position.

    Output shape is ``(B, H * W * num_anchor_points, 2)``.

    NOTE: ``conv3``/``conv4`` (and ``act3``/``act4``) are built but never used
    in ``forward``. They still appear in ``state_dict()``, so checkpoints saved
    from this model contain them; removing them would break strict loading.
    """

    def __init__(self, num_features_in, num_anchor_points=4, feature_size=256):
        super().__init__()

        self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

        # unused in forward (kept for checkpoint compatibility)
        self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act4 = nn.ReLU()

        # two output channels (x, y) per anchor point
        self.output = nn.Conv2d(feature_size, num_anchor_points * 2, kernel_size=3, padding=1)

    def forward(self, x):
        """Map a ``(B, C, H, W)`` feature map to ``(B, H*W*A, 2)`` offsets."""
        hidden = self.act2(self.conv2(self.act1(self.conv1(x))))
        offsets = self.output(hidden)
        # channels last, then fold positions and anchors into one axis
        offsets = offsets.permute(0, 2, 3, 1).contiguous()
        return offsets.view(offsets.shape[0], -1, 2)

2.3 分类层

# the network framework of the classification branch
class ClassificationModel(nn.Module):
    """Classification head: ``num_classes`` logits per anchor at every position.

    Output shape is ``(B, H * W * num_anchor_points, num_classes)``; the
    logits are returned raw.

    NOTE: ``conv3``/``conv4``/``act3``/``act4`` and ``output_act`` (Sigmoid)
    are constructed but unused in ``forward``, and ``prior`` is ignored. The
    modules are kept so ``state_dict()`` keys stay the same.
    """

    def __init__(self, num_features_in, num_anchor_points=4, num_classes=80, prior=0.01, feature_size=256):
        super().__init__()

        self.num_classes = num_classes
        self.num_anchor_points = num_anchor_points

        self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act2 = nn.ReLU()

        # unused in forward (kept for checkpoint compatibility)
        self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act3 = nn.ReLU()

        self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
        self.act4 = nn.ReLU()

        # num_classes logits per anchor point
        self.output = nn.Conv2d(feature_size, num_anchor_points * num_classes, kernel_size=3, padding=1)
        self.output_act = nn.Sigmoid()

    def forward(self, x):
        """Map a ``(B, C, H, W)`` feature map to ``(B, H*W*A, num_classes)`` logits."""
        feat = self.act1(self.conv1(x))
        feat = self.act2(self.conv2(feat))
        # (B, A*num_classes, H, W) -> (B, H, W, A*num_classes)
        logits = self.output(feat).permute(0, 2, 3, 1)
        return logits.reshape(x.shape[0], -1, self.num_classes)

2.4 计算锚点

# generate the reference points in grid layout
def generate_anchor_points(stride=16, row=3, line=3):
    """Return the ``row * line`` anchor offsets of one ``stride x stride`` cell.

    The cell is split into a ``row x line`` grid of sub-cells. Each anchor
    sits at the centre of one sub-cell and is given as an ``(x, y)`` offset
    from the cell centre.

    Example: ``stride=8, row=2, line=2`` gives x and y offsets ``[-2, 2]`` and
    the result ``[[-2, -2], [2, -2], [-2, 2], [2, 2]]`` (x varies fastest).

    Returns:
        float ndarray of shape ``(row * line, 2)``.
    """
    half_cell = stride / 2
    x_offsets = (np.arange(1, line + 1) - 0.5) * (stride / line) - half_cell
    y_offsets = (np.arange(1, row + 1) - 0.5) * (stride / row) - half_cell

    grid_x, grid_y = np.meshgrid(x_offsets, y_offsets)
    # one (x, y) row per anchor, scanning the y/x grid row by row
    return np.column_stack((grid_x.ravel(), grid_y.ravel()))

# shift the meta-anchor to get the anchor points of a whole feature map
def shift(shape, stride, anchor_points):
    """Tile per-cell anchor offsets over every cell of a feature map.

    Args:
        shape: ``(height, width)`` of the feature map, in cells.
        stride: pixel size of a cell; cell centres are at ``(i + 0.5) * stride``.
        anchor_points: ``(A, 2)`` array of ``(x, y)`` offsets from a cell centre.

    Returns:
        ``(height * width * A, 2)`` array of absolute ``(x, y)`` positions.
        Cells are scanned row by row (x varies fastest), and the ``A``
        anchors of one cell are contiguous.
    """
    centers_x = (np.arange(0, shape[1]) + 0.5) * stride
    centers_y = (np.arange(0, shape[0]) + 0.5) * stride
    grid_x, grid_y = np.meshgrid(centers_x, centers_y)
    # K = height * width cell centres, each an (x, y) pair
    centers = np.column_stack((grid_x.ravel(), grid_y.ravel()))

    num_anchors = anchor_points.shape[0]
    # (K, 1, 2) + (1, A, 2) -> (K, A, 2) via broadcasting
    tiled = centers[:, np.newaxis, :] + anchor_points.reshape((1, num_anchors, 2))
    return tiled.reshape((-1, 2))

# this class generates all reference points on all pyramid levels
class AnchorPoints(nn.Module):
    """Generate the reference (anchor) points for an input image batch.

    For each pyramid level ``p`` the image is divided (rounding up) into
    cells of ``2 ** p`` pixels. Every cell receives ``row * line`` anchor
    points (see ``generate_anchor_points``).

    Args:
        pyramid_levels: levels to generate anchors for; defaults to
            ``[3, 4, 5, 6, 7]``.
        strides: per-level pixel spacing used when tiling the cells; defaults
            to ``2 ** level``. The cell count and in-cell offsets always use
            ``2 ** level``.
        row, line: layout of the anchor grid inside a cell.
    """

    def __init__(self, pyramid_levels=None, strides=None, row=3, line=3):
        super(AnchorPoints, self).__init__()

        if pyramid_levels is None:
            self.pyramid_levels = [3, 4, 5, 6, 7]
        else:
            self.pyramid_levels = pyramid_levels

        if strides is None:
            self.strides = [2 ** x for x in self.pyramid_levels]
        else:
            # previously left unset, causing AttributeError in forward()
            self.strides = list(strides)

        self.row = row
        self.line = line

    def forward(self, image):
        """Return a float32 tensor of shape ``(1, N, 2)`` with ``(x, y)`` anchors.

        When ``image`` is a tensor the result is placed on ``image.device``;
        otherwise on CUDA if available, else CPU.
        """
        image_shape = np.array(image.shape[2:])
        # feature-map size per level: ceil(image_size / 2**level)
        image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]

        # collect per-level anchors, concatenate once (avoids repeated np.append copies)
        per_level = [np.zeros((0, 2), dtype=np.float32)]
        for idx, p in enumerate(self.pyramid_levels):
            # in-cell offsets, then tiled over every cell of this level
            anchor_points = generate_anchor_points(2 ** p, row=self.row, line=self.line)
            per_level.append(shift(image_shapes[idx], self.strides[idx], anchor_points))
        all_anchor_points = np.concatenate(per_level, axis=0)

        # add a leading batch dimension of size 1
        anchors = torch.from_numpy(np.expand_dims(all_anchor_points, axis=0).astype(np.float32))

        # follow the input's device so the anchors can be added to the model outputs
        if isinstance(image, torch.Tensor):
            return anchors.to(image.device)
        if torch.cuda.is_available():
            return anchors.cuda()
        return anchors

三、数据加载与预处理

3.1 Dataset

def build_dataset(args):
    """Return the data-loading factory for ``args.dataset_file`` (None if unsupported)."""
    if args.dataset_file != 'SHHA':
        return None
    # imported lazily so other datasets do not require this package
    from crowd_datasets.SHHA.loading_data import loading_data
    return loading_data
def loading_data(data_root):
    """Build the ShanghaiTech-A training and validation datasets.

    Both splits share the same ToTensor + ImageNet mean/std normalisation.
    The training split additionally enables random patch cropping and
    horizontal flipping.

    Returns:
        ``(train_set, val_set)``.
    """
    normalize = standard_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
    transform = standard_transforms.Compose([standard_transforms.ToTensor(), normalize])

    train_set = SHHA(data_root, train=True, transform=transform, patch=True, flip=True)
    val_set = SHHA(data_root, train=False, transform=transform)
    return train_set, val_set
class SHHA(Dataset):
    """ShanghaiTech-style crowd dataset of (image, annotated head points).

    The pairs are listed in ``train.list`` (train) or ``test.list`` (eval)
    under ``data_root``. Each line holds ``<image path> <point-file path>``
    relative to ``data_root``; lines without both fields are skipped.

    Args:
        data_root: dataset root directory containing the list files.
        transform: optional callable applied to the loaded image.
        train: use the training list and enable random scaling.
        patch: when training, cut each image into random patches (``random_crop``).
        flip: when training, flip horizontally with probability 0.5.
    """

    def __init__(self, data_root, transform=None, train=False, patch=False, flip=False):
        self.root_path = data_root

        # list files may be comma-separated; each currently names one file
        self.train_lists = "train.list"
        self.eval_list = "test.list"

        list_names = self.train_lists if train else self.eval_list
        self.img_list_file = list_names.split(',')

        self.img_map = {}
        self.img_list = []

        # load the image -> ground-truth file pairs
        for list_file in self.img_list_file:
            list_file = list_file.strip()
            with open(os.path.join(self.root_path, list_file)) as fin:
                for line in fin:
                    fields = line.split()
                    # skip lines that are not "<image path> <gt path>"
                    if len(fields) < 2:
                        continue
                    self.img_map[os.path.join(self.root_path, fields[0])] = \
                        os.path.join(self.root_path, fields[1])
        # deterministic order: sorted by image path
        self.img_list = sorted(self.img_map)

        # number of samples
        self.nSamples = len(self.img_list)

        self.transform = transform
        self.train = train
        self.patch = patch
        self.flip = flip

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(img, targets)`` for one image.

        ``targets`` is a list with one dict per image / patch holding
        ``'point'`` (``(N, 2)`` tensor), ``'image_id'`` and ``'labels'``
        (``N`` ones).

        Raises:
            IndexError: if ``index >= len(self)``.
        """
        # previously ``assert index <= len(self)`` (off by one, and stripped under -O)
        if index >= len(self):
            raise IndexError('index range error')

        img_path = self.img_list[index]
        gt_path = self.img_map[img_path]

        # load image and ground-truth points
        img, point = load_data((img_path, gt_path), self.train)

        if self.transform is not None:
            img = self.transform(img)

        if self.train:
            # data augmentation -> random scale
            scale_range = [0.7, 1.3]
            min_size = min(img.shape[1:])

            scale = random.uniform(*scale_range)

            # scale image and points together, but keep the short side above 128
            if scale * min_size > 128:
                # same as the deprecated F.upsample_bilinear (which uses align_corners=True)
                img = torch.nn.functional.interpolate(
                    img.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=True
                ).squeeze(0)
                point *= scale

        if self.train and self.patch:
            # random crop augmentation: several patches, each with its own points
            img, point = random_crop(img, point)
            point = [torch.Tensor(p) for p in point]
        else:
            # a single full image -> a single point set
            point = [point]

        # random horizontal flip with probability 0.5
        if random.random() > 0.5 and self.train and self.flip:
            width = img.shape[-1]
            if isinstance(img, torch.Tensor):
                img = torch.flip(img, dims=[-1])
            else:
                img = torch.Tensor(img[..., ::-1].copy())
            for p in point:
                p[:, 0] = width - p[:, 0]

        img = torch.Tensor(img)

        # image id is the trailing number of the file name, e.g. ".../IMG_12.jpg" -> 12
        file_stem = os.path.basename(img_path).split('.')[0]
        image_number = int(file_stem.split('_')[-1])

        # pack up related infos
        target = []
        for pts in point:
            target.append({
                'point': torch.Tensor(pts),
                'image_id': torch.Tensor([image_number]).long(),
                'labels': torch.ones([pts.shape[0]]).long(),
            })

        return img, target

3.2 装载数据

def load_data(img_gt_path, train):
    """Load one image and its annotated head points.

    Args:
        img_gt_path: ``(image_path, gt_path)``. The gt file holds one
            ``x y`` pair per line, whitespace-separated.
        train: unused; kept for interface compatibility.

    Returns:
        ``(img, points)``: an RGB ``PIL.Image`` and a float ndarray of shape
        ``(N, 2)``. It is ``(0, 2)`` when the gt file has no points.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    img_path, gt_path = img_gt_path
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError(f"could not read image: {img_path}")
    # OpenCV loads BGR; convert to RGB before building the PIL image
    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    points = []
    with open(gt_path) as f_label:
        for line in f_label:
            # split() tolerates repeated spaces/tabs; blank lines are skipped
            fields = line.split()
            if not fields:
                continue
            points.append([float(fields[0]), float(fields[1])])

    # reshape keeps the (N, 2) layout even when there are no points
    return img, np.array(points, dtype=np.float64).reshape(-1, 2)

3.3 随机裁剪增强

# random crop augmentation
def random_crop(img, den, num_patch=4, patch_size=128):
    """Cut ``num_patch`` random square patches out of a channel-first image.

    Args:
        img: image of shape ``(C, H, W)`` (torch tensor or ndarray) with
            ``H >= patch_size`` and ``W >= patch_size``.
        den: ``(N, 2)`` array of ``(x, y)`` point coordinates in ``img``. An
            empty array is treated as zero points.
        num_patch: number of patches to sample.
        patch_size: side length (pixels) of each square patch.

    Returns:
        ``(patches, points)``:
        - ``patches`` is a float64 ndarray of shape
          ``(num_patch, C, patch_size, patch_size)``.
        - ``points`` is a list with, per patch, the ``(M, 2)`` array of points
          inside it, in patch-local coordinates.
    """
    crop_h = crop_w = patch_size
    # normalise to (N, 2) so an empty annotation does not break column indexing
    den = np.asarray(den).reshape(-1, 2)
    channels, height, width = img.shape[0], img.shape[1], img.shape[2]

    result_img = np.zeros([num_patch, channels, crop_h, crop_w])
    result_den = []

    for i in range(num_patch):
        start_h = random.randint(0, height - crop_h)
        start_w = random.randint(0, width - crop_w)
        end_h = start_h + crop_h
        end_w = start_w + crop_w

        # copy the cropped rect
        result_img[i] = img[:, start_h:end_h, start_w:end_w]

        # keep points inside the patch (border points at == end are included)
        idx = ((den[:, 0] >= start_w) & (den[:, 0] <= end_w)
               & (den[:, 1] >= start_h) & (den[:, 1] <= end_h))

        # boolean indexing copies, so shifting in place does not touch ``den``
        record_den = den[idx]
        record_den[:, 0] -= start_w
        record_den[:, 1] -= start_h

        result_den.append(record_den)

    return result_img, result_den

四、网络训练

4.1 训练模块

# create the P2PNet model (and, for training, its loss criterion)
def build(args, training):
    """Build P2PNet, plus its ``SetCriterion_Crowd`` when ``training`` is true.

    Returns:
        ``model`` when ``training`` is false, otherwise ``(model, criterion)``.
    """
    model = P2PNet(build_backbone(args), args.row, args.line)
    if not training:
        # inference only needs the network
        return model

    # NOTE(review): SetCriterion_Crowd.loss_points reports its value under the
    # key 'loss_point', while the weight here is registered as 'loss_points'.
    # If the training loop combines losses via weight_dict[k] for k in the loss
    # dict, the point loss is never weighted/summed — verify against the engine.
    weight_dict = {'loss_ce': 1, 'loss_points': args.point_loss_coef}

    # Hungarian matcher pairing predicted points with ground-truth points
    matcher = build_matcher_crowd(args)

    # persons are the single foreground class; the criterion adds the
    # background slot itself (its class weights have num_classes + 1 entries)
    criterion = SetCriterion_Crowd(1,
                                   matcher=matcher, weight_dict=weight_dict,
                                   eos_coef=args.eos_coef, losses=['labels', 'points'])
    return model, criterion

4.2 匹配器

class HungarianMatcher_Crowd(nn.Module):
    """Compute a one-to-one assignment between predicted and ground-truth points.

    Targets do not include a "no object" entry, so there are usually more
    predictions than targets. The best predictions are matched 1-to-1 and the
    rest stay unmatched (and are treated as background by the loss).
    """

    def __init__(self, cost_class: float = 1, cost_point: float = 3):
        """Create the matcher.

        Params:
            cost_class: weight of the classification term (negative
                probability of the target class).
            cost_point: weight of the Euclidean (L2) distance between
                predicted and target points.

        Raises:
            ValueError: if both weights are 0.
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_point = cost_point
        if cost_class == 0 and cost_point == 0:
            raise ValueError("all costs cant be 0")

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Perform the matching.

        Params:
            outputs: dict with
                "pred_logits": ``[batch_size, num_queries, num_classes]`` logits
                "pred_points": ``[batch_size, num_queries, 2]`` predicted points
            targets: list (len = batch_size) of dicts with
                "labels": ``[num_target_points]`` class labels
                "point": ``[num_target_points, 2]`` target point coordinates

        Returns:
            A list of ``batch_size`` tuples ``(index_i, index_j)`` of int64
            tensors. ``index_i`` holds the selected prediction indices and
            ``index_j`` the matched target indices, with
            ``len(index_i) = len(index_j) = min(num_queries, num_target_points)``.
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # flatten the batch to compute all cost matrices at once
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [bs * num_queries, num_classes]
        out_points = outputs["pred_points"].flatten(0, 1)  # [bs * num_queries, 2]

        # concatenate the targets of the whole batch
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_points = torch.cat([v["point"] for v in targets])
        num_targets = tgt_points.shape[0]

        if num_targets == 0:
            # nothing to match anywhere in the batch; also avoids reshaping an
            # empty cost matrix with -1, which torch rejects
            return [(torch.empty(0, dtype=torch.int64), torch.empty(0, dtype=torch.int64))
                    for _ in range(bs)]

        # classification cost: -p[target class] (the constant 1 in 1 - p does not change the matching)
        cost_class = -out_prob[:, tgt_ids]

        # pairwise L2 distances, shape [bs * num_queries, num_targets]
        cost_point = torch.cdist(out_points, tgt_points, p=2)

        # final cost matrix for the Hungarian algorithm
        C = self.cost_point * cost_point + self.cost_class * cost_class
        C = C.view(bs, num_queries, num_targets).cpu()

        # split the target axis per image and solve each assignment separately
        sizes = [len(v["point"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
                for i, j in indices]


def build_matcher_crowd(args):
    """Create the Hungarian matcher from ``args.set_cost_class`` / ``args.set_cost_point``."""
    class_weight = args.set_cost_class
    point_weight = args.set_cost_point
    return HungarianMatcher_Crowd(cost_class=class_weight, cost_point=point_weight)

4.3 损失函数

class SetCriterion_Crowd(nn.Module):
    """P2PNet loss: Hungarian-match predictions to GT points, then sum the
    requested losses ('labels' = weighted cross-entropy, 'points' = squared
    point error normalised by the number of GT points).
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, excluding the special
                no-object category (index 0 in the logits).
            matcher: module computing the prediction/target assignment.
            weight_dict: loss name -> relative weight (stored for callers).
            eos_coef: cross-entropy weight of the no-object class.
            losses: names of the losses to compute (see ``get_loss``).
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses

        # per-class CE weights: index 0 (no object) gets eos_coef, others 1
        class_weight = torch.ones(num_classes + 1)
        class_weight[0] = eos_coef
        self.register_buffer('empty_weight', class_weight)

    def loss_labels(self, outputs, targets, indices, num_points):
        """Classification loss (weighted cross-entropy over all queries).

        Targets must contain "labels" of shape ``[num_target_points]``.
        """
        assert 'pred_logits' in outputs
        logits = outputs['pred_logits']

        batch_idx, src_idx = self._get_src_permutation_idx(indices)
        matched_labels = torch.cat([tgt["labels"][tgt_ids] for tgt, (_, tgt_ids) in zip(targets, indices)])

        # every query defaults to class 0 (no object); matched queries get their GT label
        target_classes = torch.zeros(logits.shape[:2], dtype=torch.int64, device=logits.device)
        target_classes[batch_idx, src_idx] = matched_labels

        loss_ce = F.cross_entropy(logits.transpose(1, 2), target_classes, self.empty_weight)
        return {'loss_ce': loss_ce}

    def loss_points(self, outputs, targets, indices, num_points):
        """Sum of squared coordinate errors of matched points, divided by ``num_points``."""
        assert 'pred_points' in outputs
        matched = self._get_src_permutation_idx(indices)
        src_points = outputs['pred_points'][matched]
        target_points = torch.cat([tgt['point'][tgt_ids] for tgt, (_, tgt_ids) in zip(targets, indices)], dim=0)

        squared_err = F.mse_loss(src_points, target_points, reduction='none')
        # NOTE(review): reported as 'loss_point' although build() weights 'loss_points'
        return {'loss_point': squared_err.sum() / num_points}

    def _get_src_permutation_idx(self, indices):
        # (batch index, query index) of every matched prediction
        batch_idx = torch.cat([torch.full_like(src, b) for b, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for src, _ in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # (batch index, target index) of every matched target
        batch_idx = torch.cat([torch.full_like(tgt, b) for b, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for _, tgt in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_points, **kwargs):
        dispatch = {
            'labels': self.loss_labels,
            'points': self.loss_points,
        }
        assert loss in dispatch, f'do you really want to compute {loss} loss?'
        return dispatch[loss](outputs, targets, indices, num_points, **kwargs)

    def forward(self, outputs, targets):
        """Compute all configured losses.

        Parameters:
            outputs: model output dict ('pred_logits', 'pred_points').
            targets: list of per-image dicts (len == batch_size); the required
                keys depend on the losses in use.
        """
        preds = {'pred_logits': outputs['pred_logits'], 'pred_points': outputs['pred_points']}

        # Hungarian assignment between predictions and GT points
        matches = self.matcher(preds, targets)

        # number of GT points, averaged over processes in distributed training
        point_count = sum(len(t["labels"]) for t in targets)
        point_count = torch.as_tensor([point_count], dtype=torch.float, device=preds['pred_logits'].device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(point_count)
        num_boxes = torch.clamp(point_count / get_world_size(), min=1).item()

        losses = {}
        for name in self.losses:
            losses.update(self.get_loss(name, preds, targets, matches, num_boxes))
        return losses

五、预测

def get_args_parser():
    """Return the (help-less) parent parser with the P2PNet evaluation options."""
    parser = argparse.ArgumentParser('Set parameters for P2PNet evaluation', add_help=False)
    add = parser.add_argument

    # * Backbone
    add('--backbone', default='vgg16_bn', type=str,
        help="name of the convolutional backbone to use")

    # anchor-point grid inside each feature-map cell
    add('--row', default=2, type=int, help="row number of anchor points")
    add('--line', default=2, type=int, help="line number of anchor points")

    # input / output locations
    add('--output_dir', default='output/', help='path where to save')
    add('--weight_path', default='weights/best_mae.pth',
        help='path where the trained weights saved')

    add('--gpu_id', default=-1, type=int, help='the gpu used for evaluation')

    return parser
    

def main(args, debug=False):
    """Run P2PNet on every image in ``vis/test1/`` and save visualised predictions.

    For each image:
    - resize it so both sides are multiples of 128;
    - run the model;
    - draw every predicted point whose person probability exceeds 0.5 as a
      red dot;
    - write the result to ``<output_dir>/vis/test1/<name>_pred<count>.jpg``.

    Args:
        args: namespace from ``get_args_parser`` (``gpu_id``, ``weight_path``,
            ``output_dir``, plus whatever ``build_model`` reads).
        debug: unused; kept for interface compatibility.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = '{}'.format(args.gpu_id)

    print(args)
    # gpu_id = -1 hides every GPU (set just above), so this falls back to the
    # CPU; otherwise the requested GPU is used.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # get the P2PNet
    model = build_model(args)
    model.to(device)
    # load trained weights (torch.load unpickles the file: only load trusted checkpoints)
    if args.weight_path is not None:
        checkpoint = torch.load(args.weight_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
    # convert to eval mode
    model.eval()
    # same ToTensor + ImageNet mean/std normalisation as the training data
    transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # set your image path here
    img_path = "vis/test1/"
    # join instead of string concatenation so output_dir works with or without a trailing '/'
    save_dir = os.path.join(args.output_dir, img_path)
    os.makedirs(save_dir, exist_ok=True)

    threshold = 0.5  # minimum person probability for a point to be kept
    size = 2  # radius (px) of the drawn dots

    for imgname in sorted(os.listdir(img_path)):
        src_path = os.path.join(img_path, imgname)
        if not os.path.isfile(src_path):
            continue  # skip sub-directories

        start = time.time()

        img_raw = Image.open(src_path).convert('RGB')

        # round both sides down to a multiple of 128 (at least 128, so small
        # images do not collapse to a zero-sized image)
        width, height = img_raw.size
        new_width = max(width // 128, 1) * 128
        new_height = max(height // 128, 1) * 128
        # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10)
        img_raw = img_raw.resize((new_width, new_height), Image.LANCZOS)

        # pre-processing, plus a batch dimension
        samples = transform(img_raw).unsqueeze(0).to(device)

        # run inference without building an autograd graph
        with torch.no_grad():
            outputs = model(samples)
        # probability of class 1 (person) for every anchor of the single image
        outputs_scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]
        outputs_points = outputs['pred_points'][0]

        # filter the predictions
        keep = outputs_scores > threshold
        points = outputs_points[keep].detach().cpu().numpy().tolist()
        predict_cnt = int(keep.sum())

        # draw the predictions (OpenCV works in BGR)
        img_to_draw = cv2.cvtColor(np.array(img_raw), cv2.COLOR_RGB2BGR)
        for p in points:
            img_to_draw = cv2.circle(img_to_draw, (int(p[0]), int(p[1])), size, (0, 0, 255), -1)

        # save the visualized image; cv2.imwrite reports failure only via its return value
        out_path = os.path.join(save_dir, imgname + '_pred{}.jpg'.format(predict_cnt))
        if not cv2.imwrite(out_path, img_to_draw):
            print("failed to write " + out_path)

        end = time.time()

        print(imgname + " inference time:", end - start)
    
    

if __name__ == '__main__':
    # command-line entry point: parse the evaluation options and run inference
    cli = argparse.ArgumentParser('P2PNet evaluation script', parents=[get_args_parser()])
    main(cli.parse_args())

在这里插入图片描述

  • 12
    点赞
  • 51
    收藏
    觉得还不错? 一键收藏
  • 26
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值