
FGVP: Fine-Grained Visual Prompting

Code Walkthrough

import json
import os
import time
from typing import List

import clip
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from segment_anything import sam_model_registry
from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator
from segment_anything.utils.amg import (batched_mask_to_box,
                                        remove_small_regions)
from segment_anything.utils.transforms import ResizeLongestSide
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from fine_grained_visual_prompt import FGVP_ENSEMBLE


class ClipModel(nn.Module):
    def __init__(self, model: nn.Module, tokenizer, device):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    @torch.no_grad()
    def forward(self, image: torch.Tensor, text: List[str], softmax=False):
        text = [t.lower() for t in text]
        tokenized_text = self.tokenizer(text).to(self.device)
        image_features = self.model.encode_image(image)
        text_features = self.model.encode_text(tokenized_text)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        similarity = image_features @ text_features.t()  # (N, M) cosine similarities in [-1, 1]
        if softmax:
            similarity = (100 * similarity).softmax(-1)
        return similarity
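
A minimal usage sketch of the ClipModel wrapper above (not part of the original script): it shows that the forward pass returns one cosine-similarity score per image-text pair. The "ViT-B/32" checkpoint, the CPU device, and the random tensor standing in for preprocessed crops are illustrative assumptions only.

import clip
import torch

device = torch.device("cpu")                               # fp32 on CPU keeps the dummy input simple
encoder, _ = clip.load("ViT-B/32", device=device)
scorer = ClipModel(encoder, clip.tokenize, device)

dummy_crops = torch.randn(4, 3, 224, 224, device=device)   # stands in for 4 visually prompted crops
texts = ["a photo of a photo on the wall", "a photo of a sofa"]
similarity = scorer(dummy_crops, texts)                    # shape (4, 2): one score per crop/text pair
print(similarity.shape)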


def draw_box(image, bbox, color):
    # color: bgr
    # bbox: [x1, y1, x2, y2]
    thickness = 2

    image = cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color=color, thickness=thickness)
    return image


def draw_mask(image, mask, color):
    # color: bgr
    # mask: [h, w]
    alpha = 0.3
    draw_contours = True
    contour_thickness = 2

    mask = mask > 0
    contours, hierarchy = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    overlay = image.copy()
    overlay[mask] = color
    image = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0.0)
    if draw_contours:
        image = cv2.drawContours(image, contours, -1, (255, 255, 255), contour_thickness)

    return image


def draw_text(image, text, left, top, color, as_center=False):
    # color: bgr
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    line_type = cv2.LINE_AA
    font_scale = 0.8
    thickness = 1
    bkg_alpha = 0.6
    top_relaxation = 8
    bottom_relaxation = 2

    # calculate text width and height
    retval, baseLine = cv2.getTextSize(text, fontFace=font_face, fontScale=font_scale, thickness=thickness)

    # calculate text location
    top = max(top, retval[1] + baseLine + top_relaxation)
    lt = [left, top - retval[1] - baseLine - top_relaxation]
    rb = [left + retval[0], top - bottom_relaxation]
    text_lt = [left, top - baseLine]

    if as_center:
        shift_x = (rb[0] - lt[0]) // 2
        shift_y = (rb[1] - lt[1]) // 2
        lt[0] -= shift_x
        rb[0] -= shift_x
        lt[1] -= shift_y
        rb[1] -= shift_y
        text_lt = [left - shift_x, top - baseLine - shift_y]

    overlay = image.copy()
    overlay = cv2.rectangle(overlay, lt, rb, thickness=-1, color=[0, 0, 0])
    image = cv2.addWeighted(overlay, bkg_alpha, image, 1 - bkg_alpha, 0.0)
    image = cv2.putText(image, text, text_lt, fontScale=font_scale, fontFace=font_face,
                        color=color, thickness=thickness, lineType=line_type)
    return image


def get_masks(image):
    # relies on the module-level settings defined in __main__ (sam_prompt, sam_model, device, ...)
    if sam_prompt in ("box", "keypoint"):
        ori_size = image.shape[:2]
        image = resize_transform_sam.apply_image(image)
        new_size = image.shape[:2]

        sam_inputs = torch.as_tensor(image, device=device)
        sam_inputs = sam_inputs.permute(2, 0, 1).contiguous()[None, :, :, :]
        sam_inputs = sam_model.preprocess(sam_inputs)

        if sam_prompt == "keypoint":
            with open(candidate_points, 'r') as f:
                points = json.load(f)
            in_points = resize_transform_sam.apply_coords_torch(points, ori_size)
            in_points = in_points.to(device)
            in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=device)
            in_points = (in_points[:, None, :], in_labels[:, None])
        else:
            in_points = None

        if sam_prompt == "box":
            with open(candidate_boxes, 'r') as f:
                boxes = torch.from_numpy(np.array(json.load(f)))
            in_boxes = resize_transform_sam.apply_boxes_torch(boxes, ori_size)
            in_boxes = in_boxes[:, None, :].to(device)
        else:
            in_boxes = None

        features = sam_model.image_encoder(sam_inputs)
        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=in_points,
            boxes=in_boxes,
            masks=None,
        )
        low_res_masks, iou_pred = sam_model.mask_decoder(
            image_embeddings=features,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=sam_multimask_output,
        )
        masks = F.interpolate(low_res_masks, (sam_image_size, sam_image_size),
                              mode="bilinear", align_corners=False)  # N, 1, H, W
        masks = masks[:, :1, :new_size[0], :new_size[1]]
        masks = F.interpolate(masks, ori_size, mode="bilinear", align_corners=False)
        masks = masks > sam_model.mask_threshold

        if min_mask_region_area > 0:
            # masks are already boolean at this point; clean up small holes and islands
            bit_masks = masks
            masks = []
            for mask in bit_masks:
                mask = mask.squeeze(0).cpu().numpy()
                mask, _ = remove_small_regions(mask, min_mask_region_area, mode="holes")
                mask, _ = remove_small_regions(mask, min_mask_region_area, mode="islands")
                masks.append(torch.as_tensor(mask, device=device).unsqueeze(0))
            masks = torch.stack(masks, 0)
        if recompute_box or sam_prompt == "keypoint":
            # keypoint prompts come without boxes, so (re)derive tight boxes from the cleaned masks
            boxes = batched_mask_to_box(masks.squeeze(1)).float()
    else:
        assert sam_prompt == 'grid'
        """ 
        {'segmentation': array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]]), 
       'area': 604, 
       'bbox': [338, 195, 15, 58], 
       'predicted_iou': 0.9510682225227356, 
       'point_coords': [[340.0, 199.6875]], 
       'stability_score': 0.9917627573013306, 
       'crop_box': [0, 0, 640, 426]}
        """
        # len 107 
        outputs = sam_mask_generator.generate(image)
        masks = torch.from_numpy(np.stack([x['segmentation'] for x in outputs])).unsqueeze(1)
        boxes = torch.from_numpy(np.stack([x['bbox'] for x in outputs])).float()
        boxes[:, 2] += boxes[:, 0]
        boxes[:, 3] += boxes[:, 1]

    return boxes, masks


if __name__ == "__main__":
    
    # ================== Arguments ==================
    # all command-line arguments are inlined here as plain variables, using their default values
    img_dir = '/root/VLMPT/FGVP/demo/exp2/ori.png'
    text = ['photo on the wall']
    candidate_boxes = None
    candidate_points = None
    out_dir = '/root/VLMPT/FGVP/demo/exp2'
    visual_prompt = ['blur_mask']
    expand_ratio = 0.01
    recompute_box = False
    color_line = 'red'
    color_mask = 'green'
    thickness = 2
    alpha = 0.5
    blur_std_dev = 100
    contour_scale = 1.0
    clip_model = 'ViT-L/14@336px'
    clip_pretrained = ''
    clip_image_size = 336
    clip_processing = 'resize'
    clip_crop_pct = 1.0
    sam_prompt = 'grid'
    sam_model = 'vit_h'
    sam_pretrained = '/root/autodl-tmp/CLIP/SAM/vit_h/sam_vit_h_4b8939.pth'
    sam_image_size = 1024
    min_mask_region_area = 400
    sam_multimask_output = False
    sam_neg_label = False
    points_per_side = 16
    points_per_batch = 256
    pred_iou_thresh = 0.86
    stability_score_thresh = 0.92
    stability_score_offset = 0.7
    box_nms_thresh = 0.7
    crop_n_layers = 0
    crop_nms_thresh = 0.7
    crop_overlap_ratio = 512 / 1500
    crop_n_points_downscale_factor = 2
    point_grids = None
    output_mode = 'binary_mask'
    filter_mask_thr = 0.0
    # ================== Arguments ==================
    
    
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    os.makedirs(out_dir, exist_ok=True)
    print("loading CLIP")
    start = time.time()
    # ResizeLongestSide transform with target_length = 336 (the CLIP input resolution)
    resize_transform_clip = ResizeLongestSide(clip_image_size)
    """  
    (visual): VisionTransformer
    (transformer): Transformer
    """
    encoder, _ = clip.load(clip_model, download_root=clip_pretrained, device=device)
    tokenizer = clip.tokenize
    clip_model = ClipModel(encoder, tokenizer, device)
    end1 = time.time()
    print(f"loading CLIP time: {end1 - start:.2f}s")

    print("loading SAM")
    end2 = time.time()
    # ResizeLongestSide transform with target_length = 1024 (the SAM input resolution)
    resize_transform_sam = ResizeLongestSide(sam_image_size)
    # SAM = (image_encoder) ImageEncoderViT + (prompt_encoder) PromptEncoder + (mask_decoder) MaskDecoder
    sam_model = sam_model_registry[sam_model](checkpoint=sam_pretrained).to(device)
    sam_mask_generator = SamAutomaticMaskGenerator(
        sam_model,
        points_per_side=points_per_side,
        points_per_batch=points_per_batch,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        stability_score_offset=stability_score_offset,
        box_nms_thresh=box_nms_thresh,
        crop_n_layers=crop_n_layers,
        crop_nms_thresh=crop_nms_thresh,
        crop_overlap_ratio=crop_overlap_ratio,
        crop_n_points_downscale_factor=crop_n_points_downscale_factor,
        point_grids=point_grids,
        min_mask_region_area=min_mask_region_area,
        output_mode=output_mode,
    )
    end3 = time.time()
    print(f"loading SAM time: {end3 - end2:.2f}s")

    print("loading PROMPT")
    end4 = time.time()
    # IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) -> [123.675, 116.28, 103.53] after * 255
    # IMAGENET_DEFAULT_STD  = (0.229, 0.224, 0.225) -> [58.395, 57.12, 57.375] after * 255
    pixel_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(-1, 1, 1).to(device) * 255.0
    pixel_std = torch.tensor(IMAGENET_DEFAULT_STD).view(-1, 1, 1).to(device) * 255.0
    fgvp = FGVP_ENSEMBLE(
        color_line=color_line,
        thickness=thickness,
        color_mask=color_mask,
        alpha=alpha,
        clip_processing=clip_processing,
        clip_image_size=clip_image_size,
        resize_transform_clip=resize_transform_clip,
        pixel_mean=pixel_mean,
        pixel_std=pixel_std,
        blur_std_dev=blur_std_dev,
        mask_threshold=sam_model.mask_threshold,
        contour_scale=contour_scale,
        device=device,
    )
    end5 = time.time()
    print(f"loading PROMPT time: {end5 - end4:.2f}s")

    # load image
    end6 = time.time()
    image = cv2.imread(img_dir)
    # (426, 640)
    real_size = image.shape[:2]
    end7 = time.time()
    print(f"loading image time: {end7 - end6:.2f}s")

    end8 = time.time()
    # candidate boxes (xyxy) returned by get_masks for the demo image:
    '''  
    tensor([[338., 195., 353., 253.],
        [  0.,  77.,  29., 238.],
        [ 84., 327., 187., 425.],
        [  0., 231.,  76., 285.],
        [  0., 249., 133., 424.],
        [  0.,  32.,  27., 138.],
        [492.,  44., 532., 150.],
        [413.,   2., 434.,  21.],
        [  0., 249.,  66., 357.],
        [591., 191., 639., 222.],
        [487., 178., 524., 301.],
        [ 48., 114.,  86., 163.],
        [489.,  81., 512., 120.],
        [144., 200., 189., 279.],
        [538., 269., 639., 377.],
        [ 93., 289., 120., 339.],
        [106., 178., 157., 240.],
        [380.,  99., 401., 147.],
        [398., 268., 484., 323.],
        [375.,  46., 416., 147.],
        [111., 327., 182., 405.],
        [491.,  44., 532.,  99.],
        [417.,  61., 453.,  91.],
        [527.,  51., 564., 303.],
        [383., 175., 524., 302.],
        [607., 234., 639., 282.],
        [ 73., 241., 174., 334.],
        [375.,  46., 417.,  99.],
        [534.,  51., 564., 254.],
        [ 69.,   0., 147., 181.],
        [411., 169., 479., 268.],
        [602.,   0., 639., 227.],
        [235., 361., 327., 425.],
        [194.,  24., 278., 140.],
        [294.,  84., 332., 130.],
        [208.,  43., 225., 124.],
        [289., 148., 341., 176.],
        [ 46., 285.,  96., 339.],
        [538., 269., 639., 318.],
        [412., 253., 475., 268.],
        [135., 133., 211., 187.],
        [158., 301., 230., 419.],
        [  0.,   0.,  76., 116.],
        [162., 178., 192., 202.],
        [489., 138., 542., 171.],
        [412., 170., 478., 184.],
        [412., 184., 478., 254.],
        [398., 136., 489., 323.],
        [240., 164., 276., 202.],
        [574., 222., 639., 281.],
        [207.,  40., 264., 125.],
        [  0., 350.,  35., 407.],
        [207., 211., 269., 312.],
        [ 84.,  42., 138., 119.],
        [373.,  41., 385.,  54.],
        [292., 279., 307., 293.],
        [401., 301., 423., 310.],
        [262., 293., 402., 343.],
        [188., 197., 292., 288.],
        [ 69.,   0., 207., 181.],
        [ 87., 268., 103., 288.],
        [134., 290., 155., 313.],
        [271., 219., 342., 294.],
        [453.,  55., 497., 133.],
        [  0., 232., 118., 357.],
        [260., 262., 275., 282.],
        [ 48., 310.,  67., 336.],
        [527., 138., 538., 150.],
        [175., 272., 277., 345.],
        [496., 303., 506., 318.],
        [530.,  35., 547.,  51.],
        [287., 321., 302., 349.],
        [484., 329., 526., 349.],
        [300.,   0., 310.,  82.],
        [489., 137., 511., 150.],
        [375.,  46., 453.,  99.],
        [ 91., 185., 116., 215.],
        [413., 114., 427., 142.],
        [457., 109., 463., 123.],
        [290., 361., 596., 425.],
        [363., 311., 639., 424.],
        [  0.,   2., 101., 271.],
        [265., 202., 346., 370.],
        [217., 280., 639., 424.],
        [423., 303., 462., 315.],
        [  0., 338., 133., 425.],
        [ 26., 101., 102., 278.],
        [ 84., 380., 187., 425.],
        [453.,   0., 469.,  19.],
        [262., 293., 401., 425.],
        [149., 269., 327., 425.],
        [ 69.,   0., 618., 240.],
        [ 85., 328., 596., 425.],
        [348.,  15., 576., 304.],
        [255., 139., 277., 164.],
        [ 69.,   0., 388., 196.],
        [  0.,   0.,  76., 238.],
        [470.,   0., 618., 237.],
        [255., 145., 270., 163.],
        [457., 101., 463., 123.],
        [150., 269., 239., 423.],
        [ 91.,  49., 132., 119.],
        [537., 302., 547., 346.],
        [107., 243., 144., 257.],
        [618., 327., 635., 378.],
        [496., 295., 507., 318.],
        [213.,   0., 226.,  27.]])
    '''
    # the corresponding masks form a boolean tensor with one (1, H, W) mask per candidate
    # boxes: torch.Size([107, 4]); masks: torch.Size([107, 1, 426, 640])
    boxes, masks = get_masks(image)
    end9 = time.time()
    print(f"get masks time: {end9 - end8:.2f}s")
    centers = torch.stack([boxes[:, 0::2].mean(1), boxes[:, 1::2].mean(1)], 1)

    res = image.copy()
    for box, mask in zip(boxes, masks):
        # array([338, 195, 353, 253])
        box = box.detach().cpu().numpy().astype(int)
        # (426, 640)
        mask = mask.squeeze().detach().cpu().numpy()
        res = draw_box(res, box, (0, 255, 0))
        res = draw_mask(res, mask, (0, 255, 0))
    cv2.imwrite(os.path.join(out_dir, "candidates.jpg"), res)

    text_inputs = [f"a photo of a {t}" for t in text]
    # torch.Size([107, 3, 336, 336])
    clip_inputs = torch.cat([fgvp(vp, image[:, :, ::-1], centers, boxes, masks) for vp in visual_prompt])

    # os.makedirs(os.path.join(out_dir, 'fgvp'), exist_ok=True)
    # for idx, clip_input in enumerate(clip_inputs):
    #     clip_input = clip_input * pixel_std + pixel_mean
    #     clip_input = clip_input.permute(1, 2, 0).cpu().numpy().astype(np.uint8)[..., ::-1]
    #     cv2.imwrite(os.path.join(out_dir, f"fgvp/{idx}.jpg"), clip_input)

    end10 = time.time()
    # logits_per_image: torch.Size([107, 1]) similarity scores, dumped below:
    '''  
    tensor([[0.1727],
        [0.1602],
        [0.1897],
        [0.1495],
        [0.1699],
        [0.1674],
        [0.1772],
        [0.1818],
        [0.1854],
        [0.1613],
        [0.1852],
        [0.1753],
        [0.1937],
        [0.2128],
        [0.1696],
        [0.1676],
        [0.1700],
        [0.1527],
        [0.1765],
        [0.1923],
        [0.1746],
        [0.1829],
        [0.2363],
        [0.1586],
        [0.1869],
        [0.1709],
        [0.1648],
        [0.1886],
        [0.1627],
        [0.2219],
        [0.2025],
        [0.1833],
        [0.1614],
        [0.2544],
        [0.2332],
        [0.2015],
        [0.2114],
        [0.2126],
        [0.1688],
        [0.1761],
        [0.1741],
        [0.1748],
        [0.1866],
        [0.1820],
        [0.1730],
        [0.1859],
        [0.2030],
        [0.1993],
        [0.1868],
        [0.1561],
        [0.2300],
        [0.1808],
        [0.1710],
        [0.2117],
        [0.1836],
        [0.1490],
        [0.1708],
        [0.1792],
        [0.1803],
        [0.2128],
        [0.1753],
        [0.1703],
        [0.1860],
        [0.1849],
        [0.1765],
        [0.1646],
        [0.1763],
        [0.1794],
        [0.1821],
        [0.1802],
        [0.1843],
        [0.1666],
        [0.1646],
        [0.1971],
        [0.1659],
        [0.2128],
        [0.1714],
        [0.1729],
        [0.1689],
        [0.1970],
        [0.1733],
        [0.1980],
        [0.1871],
        [0.1863],
        [0.1749],
        [0.1672],
        [0.1686],
        [0.1951],
        [0.1862],
        [0.1798],
        [0.1667],
        [0.2312],
        [0.2043],
        [0.2250],
        [0.1881],
        [0.2422],
        [0.1824],
        [0.2076],
        [0.1792],
        [0.1775],
        [0.1840],
        [0.2229],
        [0.1613],
        [0.1648],
        [0.1724],
        [0.1715],
        [0.1959]], device='cuda:0', dtype=torch.float16)
    '''
    logits_per_image = clip_model(clip_inputs, text_inputs)
    end11 = time.time()
    print(f"CLIP time: {end11 - end10:.2f}s")
    logits_per_image = logits_per_image.view(len(visual_prompt), len(boxes), len(text_inputs)).mean(0)

    # scores, row_inds = logits_per_image.topk(1, dim=1)
    # N = 107 candidate regions, M = 1 text query
    N, M = logits_per_image.shape
    # per-candidate matching cost exp(-similarity), dumped below:
    ''' 
    tensor([[0.8413],
        [0.8521],
        [0.8271],
        [0.8613],
        [0.8438],
        [0.8457],
        [0.8374],
        [0.8340],
        [0.8306],
        [0.8511],
        [0.8311],
        [0.8394],
        [0.8237],
        [0.8081],
        [0.8442],
        [0.8457],
        [0.8438],
        [0.8584],
        [0.8384],
        [0.8252],
        [0.8398],
        [0.8330],
        [0.7896],
        [0.8535],
        [0.8296],
        [0.8428],
        [0.8481],
        [0.8281],
        [0.8496],
        [0.8008],
        [0.8169],
        [0.8325],
        [0.8511],
        [0.7754],
        [0.7920],
        [0.8174],
        [0.8096],
        [0.8086],
        [0.8447],
        [0.8384],
        [0.8403],
        [0.8398],
        [0.8296],
        [0.8335],
        [0.8413],
        [0.8306],
        [0.8164],
        [0.8193],
        [0.8296],
        [0.8555],
        [0.7944],
        [0.8345],
        [0.8428],
        [0.8091],
        [0.8320],
        [0.8613],
        [0.8428],
        [0.8359],
        [0.8350],
        [0.8081],
        [0.8394],
        [0.8433],
        [0.8301],
        [0.8311],
        [0.8384],
        [0.8481],
        [0.8384],
        [0.8359],
        [0.8335],
        [0.8350],
        [0.8315],
        [0.8467],
        [0.8481],
        [0.8213],
        [0.8472],
        [0.8081],
        [0.8423],
        [0.8413],
        [0.8447],
        [0.8213],
        [0.8408],
        [0.8203],
        [0.8291],
        [0.8301],
        [0.8394],
        [0.8462],
        [0.8447],
        [0.8228],
        [0.8301],
        [0.8354],
        [0.8462],
        [0.7935],
        [0.8149],
        [0.7983],
        [0.8286],
        [0.7847],
        [0.8335],
        [0.8125],
        [0.8359],
        [0.8374],
        [0.8320],
        [0.8003],
        [0.8511],
        [0.8481],
        [0.8418],
        [0.8423],
        [0.8223]], device='cuda:0', dtype=torch.float16)
    '''
    # cost: torch.Size([107, 1]); cost = exp(-similarity), so a higher similarity gives a lower cost
    cost = torch.exp(-logits_per_image)
    # Hungarian matching between regions and text queries; for the demo: row_inds = [33], col_inds = [0]
    row_inds, col_inds = linear_sum_assignment(cost.cpu().numpy())

    res = image.copy()
    print(boxes[row_inds])
    # tensor([[194.,  24., 278., 140.]])
    for row_ind, col_ind in zip(row_inds, col_inds):
        # array([194,  24, 278, 140])
        box = boxes[row_ind].detach().cpu().numpy().astype(int)
        # (426, 640)
        mask = masks[row_ind].squeeze().detach().cpu().numpy()
        # 'photo on the wall'
        caption = text[col_ind]
        res = draw_box(res, box, (0, 255, 0))
        res = draw_mask(res, mask, (0, 255, 0))
        res = draw_text(res, caption, box[0], box[1], (0, 255, 0))
    cv2.imwrite(os.path.join(out_dir, "res_pre.jpg"), res)
    
    

ResizeLongestSide

This section looks at ResizeLongestSide, the class imported from segment_anything.utils.transforms above. It resizes images, point coordinates, and boxes so that the longest side matches a target length, and it provides methods for both NumPy arrays and batched Torch tensors.

A method-by-method walkthrough follows:

class ResizeLongestSide:
    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

The constructor takes a single argument, target_length, the desired length of the longest side after resizing.

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

apply_image resizes an image given as a NumPy array. It computes the target size with get_preprocess_shape and then scales the image to that size using resize and to_pil_image (both from torchvision.transforms.functional).

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

apply_coords rescales point coordinates. Given a NumPy array of coordinates and the original image size, it multiplies the x coordinates by new_w / old_w and the y coordinates by new_h / old_h, i.e. by the ratio between the resized and the original size.
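
A tiny worked example of this scaling, using numbers from the demo above as assumptions (a 426x640 image resized so its longest side is 1024, and one of SAM's grid points):

old_h, old_w = 426, 640
new_h, new_w = 682, 1024              # get_preprocess_shape(426, 640, 1024)
x, y = 340.0, 199.6875                # a point in original-image coordinates
x_new = x * (new_w / old_w)           # 340.0 * 1.6 = 544.0
y_new = y * (new_h / old_h)           # 199.6875 * 682 / 426 ≈ 319.7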

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

apply_boxes rescales boxes. It reshapes the (N, 4) xyxy boxes to (N, 2, 2) so that each corner becomes a point, rescales the corners with apply_coords, and reshapes the result back to (N, 4).

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        # expects a batched image tensor of shape B x C x H x W
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

apply_image_torch is the batched Torch-tensor counterpart of apply_image. It expects a B x C x H x W tensor, computes the target size from the tensor's height and width with get_preprocess_shape, and resizes it with F.interpolate using bilinear interpolation and antialiasing.

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

apply_coords_torch is the Torch-tensor counterpart of apply_coords: given coordinates as a tensor and the original image size, it scales them by the same new/old width and height ratios.

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

apply_boxes_torch is the Torch-tensor counterpart of apply_boxes: it reshapes the boxes to (N, 2, 2), rescales the corners with apply_coords_torch, and reshapes the result back to (N, 4).

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

get_preprocess_shape is a static method that computes the output size for a given input size and target longest-side length. It scales both sides by long_side_length / max(oldh, oldw) and rounds to the nearest integer, so the longest side becomes exactly long_side_length while the aspect ratio is preserved.
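
For instance, with the 426x640 demo image and SAM's target length of 1024 (the setting used in the script above), the computation goes:

scale = 1024 / max(426, 640)          # 1.6
new_h = int(426 * scale + 0.5)        # 682
new_w = int(640 * scale + 0.5)        # 1024
# -> (682, 1024): the longest side is exactly 1024 and the aspect ratio is preserved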

In short, this class implements the image, coordinate, and box resizing used for pre- and post-processing in the pipeline above, with matching methods for NumPy arrays and Torch tensors.
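
To tie the pieces together, here is a short, self-contained usage sketch under the same assumptions as the demo above (a 426x640 image and a target length of 1024); the zero-filled arrays simply stand in for real data:

import numpy as np
import torch
from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)                   # SAM's input resolution

image = np.zeros((426, 640, 3), dtype=np.uint8)       # stand-in for cv2.imread(...)
resized = transform.apply_image(image)                # -> shape (682, 1024, 3)

boxes = torch.tensor([[194., 24., 278., 140.]])       # xyxy boxes in original coordinates
boxes_resized = transform.apply_boxes_torch(boxes, image.shape[:2])

batch = torch.zeros(2, 3, 426, 640)                   # B x C x H x W float batch
batch_resized = transform.apply_image_torch(batch)    # -> torch.Size([2, 3, 682, 1024])

print(resized.shape, boxes_resized, batch_resized.shape)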
