FGVP: Fine-Grained Visual Prompting
Code Walkthrough
import json
import os
from typing import List
import clip
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from segment_anything import sam_model_registry
from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator
from segment_anything.utils.amg import (batched_mask_to_box,
                                        remove_small_regions)
from segment_anything.utils.transforms import ResizeLongestSide
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from fine_grained_visual_prompt import FGVP_ENSEMBLE
class ClipModel(nn.Module):
    def __init__(self, model: nn.Module, tokenizer, device):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    @torch.no_grad()
    def forward(self, image: torch.Tensor, text: List[str], softmax=False):
        text = [t.lower() for t in text]
        tokenized_text = self.tokenizer(text).to(self.device)
        image_features = self.model.encode_image(image)
        text_features = self.model.encode_text(tokenized_text)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        similarity = image_features @ text_features.t()  # N, M (0~1)
        if softmax:
            similarity = (100 * similarity).softmax(-1)
        return similarity
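# Usage sketch (hypothetical tensors; the __main__ block below builds clip_model
# the same way): for N prompted crops and M phrases, forward() returns an [N, M]
# cosine-similarity matrix, so the best phrase per crop is the row argmax:
#   crops = torch.randn(4, 3, 336, 336, device=device)     # N = 4 preprocessed crops
#   sims = clip_model(crops, ["a red car", "a blue bus"])  # [4, 2]
#   best = sims.argmax(dim=1)                              # [4]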
def draw_box(image, bbox, color):
    # color: bgr
    # bbox: [x1, y1, x2, y2]
    thickness = 2
    image = cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color=color, thickness=thickness)
    return image
def draw_mask(image, mask, color):
    # color: bgr
    # mask: [h, w]
    alpha = 0.3
    draw_contours = True
    contour_thickness = 2
    mask = mask > 0
    contours, hierarchy = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    overlay = image.copy()
    overlay[mask] = color
    image = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0.0)
    if draw_contours:
        image = cv2.drawContours(image, contours, -1, (255, 255, 255), contour_thickness)
    return image
def draw_text(image, text, left, top, color, as_center=False):
    # color: bgr
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    line_type = cv2.LINE_AA
    font_scale = 0.8
    thickness = 1
    bkg_alpha = 0.6
    top_relaxation = 8
    bottom_relaxation = 2
    # calculate text width and height
    retval, baseLine = cv2.getTextSize(text, fontFace=font_face, fontScale=font_scale, thickness=thickness)
    # calculate text location
    top = max(top, retval[1] + baseLine + top_relaxation)
    lt = [left, top - retval[1] - baseLine - top_relaxation]
    rb = [left + retval[0], top - bottom_relaxation]
    text_lt = [left, top - baseLine]
    if as_center:
        shift_x = (rb[0] - lt[0]) // 2
        shift_y = (rb[1] - lt[1]) // 2
        lt[0] -= shift_x
        rb[0] -= shift_x
        lt[1] -= shift_y
        rb[1] -= shift_y
        text_lt = [left - shift_x, top - baseLine - shift_y]
    overlay = image.copy()
    overlay = cv2.rectangle(overlay, lt, rb, thickness=-1, color=[0, 0, 0])
    image = cv2.addWeighted(overlay, bkg_alpha, image, 1 - bkg_alpha, 0.0)
    image = cv2.putText(image, text, text_lt, fontScale=font_scale, fontFace=font_face,
                        color=color, thickness=thickness, lineType=line_type)
    return image
def get_masks(args, image):
    if sam_prompt == "box" or sam_prompt == "keypoint":
        ori_size = image.shape[:2]
        image = resize_transform_sam.apply_image(image)
        new_size = image.shape[:2]
        sam_inputs = torch.as_tensor(image, device=device)
        sam_inputs = sam_inputs.permute(2, 0, 1).contiguous()[None, :, :, :]
        sam_inputs = sam_model.preprocess(sam_inputs)
        if sam_prompt == "keypoint":
            with open(candidate_points, 'r') as f:
                # apply_coords_torch expects a tensor, so convert the JSON list first
                points = torch.from_numpy(np.array(json.load(f))).float()
            in_points = resize_transform_sam.apply_coords_torch(points, ori_size)
            in_points = in_points.to(device)
            in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=device)
            in_points = (in_points[:, None, :], in_labels[:, None])
        else:
            in_points = None
        if sam_prompt == "box":
            with open(candidate_boxes, 'r') as f:
                boxes = torch.from_numpy(np.array(json.load(f)))
            in_boxes = resize_transform_sam.apply_boxes_torch(boxes, ori_size)
            in_boxes = in_boxes[:, None, :].to(device)
        else:
            in_boxes = None
        features = sam_model.image_encoder(sam_inputs)
        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=in_points,
            boxes=in_boxes,
            masks=None,
        )
        low_res_masks, iou_pred = sam_model.mask_decoder(
            image_embeddings=features,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=sam_multimask_output,
        )
        # upsample to the padded SAM resolution, crop away the padding, then
        # resize back to the original image size and binarize
        masks = F.interpolate(low_res_masks, (sam_image_size, sam_image_size),
                              mode="bilinear", align_corners=False)  # N, 1, H, W
        masks = masks[:, :1, :new_size[0], :new_size[1]]
        masks = F.interpolate(masks, ori_size, mode="bilinear", align_corners=False)
        masks = masks > sam_model.mask_threshold
        if min_mask_region_area > 0:
            # postprocess: fill small holes, then drop small disconnected islands
            bit_masks = masks
            masks = []
            for mask in bit_masks:
                mask = mask.squeeze(0).cpu().numpy()
                mask, changed = remove_small_regions(mask, min_mask_region_area, mode="holes")
                mask, changed = remove_small_regions(mask, min_mask_region_area, mode="islands")
                mask = torch.as_tensor(mask, device=device).unsqueeze(0)
                masks.append(mask)
            masks = torch.stack(masks, 0)
        if recompute_box:
            # derive tight boxes from the (cleaned) binary masks
            boxes = batched_mask_to_box(masks.squeeze(1)).float()
    else:
        assert sam_prompt == 'grid'
        # each entry produced by the automatic mask generator looks like:
        """
        {'segmentation': array([[False, False, False, ..., False, False, False],
                                ...]),
         'area': 604,
         'bbox': [338, 195, 15, 58],
         'predicted_iou': 0.9510682225227356,
         'point_coords': [[340.0, 199.6875]],
         'stability_score': 0.9917627573013306,
         'crop_box': [0, 0, 640, 426]}
        """
        outputs = sam_mask_generator.generate(image)  # len 107 for the demo image
        masks = torch.from_numpy(np.stack([x['segmentation'] for x in outputs])).unsqueeze(1)
        boxes = torch.from_numpy(np.stack([x['bbox'] for x in outputs])).float()
        # convert [x, y, w, h] to [x1, y1, x2, y2]
        boxes[:, 2] += boxes[:, 0]
        boxes[:, 3] += boxes[:, 1]
    return boxes, masks
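# For reference, with sam_prompt = 'grid' on the 640x426 demo image:
#   boxes, masks = get_masks(None, image)
#   boxes.shape  # torch.Size([107, 4]), xyxy pixel coordinates
#   masks.shape  # torch.Size([107, 1, 426, 640]), boolean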
if __name__ == "__main__":
    # ================== Parameters ==================
    # The original script reads these from argparse; here they are inlined as
    # plain variables with their default values.
    img_dir = '/root/VLMPT/FGVP/demo/exp2/ori.png'
    text = ['photo on the wall']
    candidate_boxes = None
    candidate_points = None
    out_dir = '/root/VLMPT/FGVP/demo/exp2'
    visual_prompt = ['blur_mask']
    expand_ratio = 0.01
    recompute_box = False
    color_line = 'red'
    color_mask = 'green'
    thickness = 2
    alpha = 0.5
    blur_std_dev = 100
    contour_scale = 1.0
    clip_model = 'ViT-L/14@336px'
    clip_pretrained = ''
    clip_image_size = 336
    clip_processing = 'resize'
    clip_crop_pct = 1.0
    sam_prompt = 'grid'
    sam_model = 'vit_h'
    sam_pretrained = '/root/autodl-tmp/CLIP/SAM/vit_h/sam_vit_h_4b8939.pth'
    sam_image_size = 1024
    min_mask_region_area = 400
    sam_multimask_output = False
    sam_neg_label = False
    points_per_side = 16
    points_per_batch = 256
    pred_iou_thresh = 0.86
    stability_score_thresh = 0.92
    stability_score_offset = 0.7
    box_nms_thresh = 0.7
    crop_n_layers = 0
    crop_nms_thresh = 0.7
    crop_overlap_ratio = 512 / 1500
    crop_n_points_downscale_factor = 2
    point_grids = None
    output_mode = 'binary_mask'
    filter_mask_thr = 0.0
    # ================== Parameters ==================
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    os.makedirs(out_dir, exist_ok=True)
    print("loading CLIP")
    import time
    start = time.time()
    # resize_transform_clip.target_length = 336
    resize_transform_clip = ResizeLongestSide(clip_image_size)
    # the loaded CLIP model contains a VisionTransformer (visual) and a text Transformer
    encoder, _ = clip.load(clip_model, download_root=clip_pretrained, device=device)
    tokenizer = clip.tokenize
    clip_model = ClipModel(encoder, tokenizer, device)  # note: shadows the model-name string
    end1 = time.time()
    print(f"loading CLIP time: {end1 - start:.2f}s")
    print("loading SAM")
    end2 = time.time()
    # resize_transform_sam.target_length = 1024
    resize_transform_sam = ResizeLongestSide(sam_image_size)
    # SAM = ImageEncoderViT + PromptEncoder + MaskDecoder
    sam_model = sam_model_registry[sam_model](checkpoint=sam_pretrained).to(device)
    sam_mask_generator = SamAutomaticMaskGenerator(
        sam_model,
        points_per_side=points_per_side,
        points_per_batch=points_per_batch,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        stability_score_offset=stability_score_offset,
        box_nms_thresh=box_nms_thresh,
        crop_n_layers=crop_n_layers,
        crop_nms_thresh=crop_nms_thresh,
        crop_overlap_ratio=crop_overlap_ratio,
        crop_n_points_downscale_factor=crop_n_points_downscale_factor,
        point_grids=point_grids,
        min_mask_region_area=min_mask_region_area,
        output_mode=output_mode,
    )
    end3 = time.time()
    print(f"loading SAM time: {end3 - end2:.2f}s")
print("loading PROMPT")
end4 = time.time()
# (0.485, 0.456, 0.406)
'''
tensor([[[123.6750]],
[[116.2800]],
[[103.5300]]], device='cuda:0')
'''
# (0.229, 0.224, 0.225)
'''
tensor([[[58.3950]],
[[57.1200]],
[[57.3750]]], device='cuda:0')
'''
pixel_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(-1, 1, 1).to(device) * 255.0
pixel_std = torch.tensor(IMAGENET_DEFAULT_STD).view(-1, 1, 1).to(device) * 255.0
fgvp = FGVP_ENSEMBLE(
color_line=color_line,
thickness=thickness,
color_mask=color_mask,
alpha=alpha,
clip_processing=clip_processing,
clip_image_size=clip_image_size,
resize_transform_clip=resize_transform_clip,
pixel_mean=pixel_mean,
pixel_std=pixel_std,
blur_std_dev=blur_std_dev,
mask_threshold=sam_model.mask_threshold,
contour_scale=contour_scale,
device=device,
)
end5 = time.time()
print(f"loading PROMPT time: {end5 - end4:.2f}s")
    # load image
    end6 = time.time()
    image = cv2.imread(img_dir)
    real_size = image.shape[:2]  # (426, 640)
    end7 = time.time()
    print(f"loading image time: {end7 - end6:.2f}s")
    end8 = time.time()
    args = None  # get_masks ignores this; kept for signature compatibility
    # get_masks output for the demo image: 107 proposals
    #   boxes: torch.Size([107, 4]), xyxy pixel coordinates, e.g.
    #     tensor([[338., 195., 353., 253.],
    #             [  0.,  77.,  29., 238.],
    #             [ 84., 327., 187., 425.],
    #             ...,
    #             [496., 295., 507., 318.],
    #             [213.,   0., 226.,  27.]])
    #   masks: torch.Size([107, 1, 426, 640]), boolean
    boxes, masks = get_masks(args, image)
    end9 = time.time()
    print(f"get masks time: {end9 - end8:.2f}s")
    # box centers, used by point-style visual prompts
    centers = torch.stack([boxes[:, 0::2].mean(1), boxes[:, 1::2].mean(1)], 1)
    # visualize all candidate proposals
    res = image.copy()
    for box, mask in zip(boxes, masks):
        box = box.detach().cpu().numpy().astype(int)  # e.g. array([338, 195, 353, 253])
        mask = mask.squeeze().detach().cpu().numpy()  # (426, 640)
        res = draw_box(res, box, (0, 255, 0))
        res = draw_mask(res, mask, (0, 255, 0))
    cv2.imwrite(os.path.join(out_dir, "candidates.jpg"), res)
    text_inputs = [f"a photo of a {t}" for t in text]
    # apply each fine-grained visual prompt ('blur_mask') to every proposal and
    # stack the resulting CLIP inputs: torch.Size([107, 3, 336, 336])
    clip_inputs = torch.cat([fgvp(vp, image[:, :, ::-1], centers, boxes, masks) for vp in visual_prompt])
    # optionally dump the prompted crops for inspection:
    # os.makedirs(os.path.join(out_dir, 'fgvp'), exist_ok=True)
    # for idx, clip_input in enumerate(clip_inputs):
    #     clip_input = clip_input * pixel_std + pixel_mean
    #     clip_input = clip_input.permute(1, 2, 0).cpu().numpy().astype(np.uint8)[..., ::-1]
    #     cv2.imwrite(os.path.join(out_dir, f"fgvp/{idx}.jpg"), clip_input)
    end10 = time.time()
    # logits_per_image: torch.Size([107, 1]) similarity scores, e.g.
    #   tensor([[0.1727], [0.1602], [0.1897], ..., [0.1724], [0.1715], [0.1959]],
    #          device='cuda:0', dtype=torch.float16)
    logits_per_image = clip_model(clip_inputs, text_inputs)
    end11 = time.time()
    print(f"CLIP time: {end11 - end10:.2f}s")
    # average over the visual-prompt ensemble (here a single prompt); shape stays [107, 1]
    logits_per_image = logits_per_image.view(len(visual_prompt), len(boxes), len(text_inputs)).mean(0)
    # scores, row_inds = logits_per_image.topk(1, dim=1)
    N, M = logits_per_image.shape  # N = 107 proposals, M = 1 text query
    # cost = exp(-similarity): torch.Size([107, 1]), e.g.
    #   tensor([[0.8413], [0.8521], [0.8271], ..., [0.8418], [0.8423], [0.8223]],
    #          device='cuda:0', dtype=torch.float16)
    cost = torch.exp(-logits_per_image)
    # Hungarian matching: minimizing exp(-similarity) maximizes similarity while
    # assigning each text query to a distinct proposal; with M = 1 this reduces
    # to an argmax over the 107 proposals (here it selects row 33)
    row_inds, col_inds = linear_sum_assignment(cost.cpu().numpy())  # array([33]), array([0])
    res = image.copy()
    print(boxes[row_inds])  # tensor([[194., 24., 278., 140.]])
    for row_ind, col_ind in zip(row_inds, col_inds):
        box = boxes[row_ind].detach().cpu().numpy().astype(int)  # array([194, 24, 278, 140])
        mask = masks[row_ind].squeeze().detach().cpu().numpy()  # (426, 640)
        caption = text[col_ind]  # 'photo on the wall'
        res = draw_box(res, box, (0, 255, 0))
        res = draw_mask(res, mask, (0, 255, 0))
        res = draw_text(res, caption, box[0], box[1], (0, 255, 0))
    cv2.imwrite(os.path.join(out_dir, "res_pre.jpg"), res)
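The matching step is worth isolating. Below is a minimal standalone sketch (with made-up similarity scores in place of CLIP outputs) showing why linear_sum_assignment on exp(-similarity) picks the highest-scoring proposal for each query while keeping the assignment one-to-one:

import numpy as np
from scipy.optimize import linear_sum_assignment

# made-up [N proposals x M queries] similarity scores
sims = np.array([[0.17, 0.60],
                 [0.25, 0.20],
                 [0.19, 0.30]])
cost = np.exp(-sims)  # exp(-x) is monotone decreasing, so min cost == max similarity
rows, cols = linear_sum_assignment(cost)
for r, c in zip(rows, cols):
    print(f"query {c} -> proposal {r} (sim={sims[r, c]:.2f})")
# query 0 is matched to proposal 1 (0.25) and query 1 to proposal 0 (0.60);
# no two queries can claim the same proposal

With a single query, as in the demo, this is just an argmax over proposals; with several queries it additionally guarantees that no two phrases are grounded to the same region.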
ResizeLongestSide
This code implements a class named ResizeLongestSide that rescales images, point coordinates, and boxes so that the longest side matches a target length. It provides methods for NumPy arrays and for batched torch tensors (the snippets below are excerpted from segment_anything.utils.transforms and rely on its imports: Tuple, deepcopy, torchvision's resize and to_pil_image, and torch.nn.functional as F).
A method-by-method explanation follows:
class ResizeLongestSide:
    def __init__(self, target_length: int) -> None:
        self.target_length = target_length
The constructor takes a single argument, target_length, the length the longest side will be resized to.
def apply_image(self, image: np.ndarray) -> np.ndarray:
    target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
    return np.array(resize(to_pil_image(image), target_size))
apply_image resizes an image given as a NumPy array: it computes the target size with get_preprocess_shape, converts the array to a PIL image, resizes it, and converts back to a NumPy array.
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
    old_h, old_w = original_size
    new_h, new_w = self.get_preprocess_shape(
        original_size[0], original_size[1], self.target_length
    )
    coords = deepcopy(coords).astype(float)
    coords[..., 0] = coords[..., 0] * (new_w / old_w)
    coords[..., 1] = coords[..., 1] * (new_h / old_h)
    return coords
apply_coords rescales point coordinates. Given a NumPy array of (x, y) coordinates and the original image size, it multiplies x by new_w / old_w and y by new_h / old_h, i.e. the ratio between the resized and original dimensions.
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
    boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
    return boxes.reshape(-1, 4)
apply_boxes rescales boxes. It reshapes each [x1, y1, x2, y2] box to (-1, 2, 2) so the two corners can be treated as points, rescales them with apply_coords, and reshapes the result back to (-1, 4).
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
    # expects a batched BxCxHxW tensor, so height and width are dims 2 and 3
    target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
    return F.interpolate(
        image, target_size, mode="bilinear", align_corners=False, antialias=True
    )
apply_image_torch handles images given as batched torch tensors. It computes the target size with get_preprocess_shape from the tensor's spatial dimensions and resizes with bilinear interpolation via F.interpolate.
def apply_coords_torch(
    self, coords: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
    old_h, old_w = original_size
    new_h, new_w = self.get_preprocess_shape(
        original_size[0], original_size[1], self.target_length
    )
    coords = deepcopy(coords).to(torch.float)
    coords[..., 0] = coords[..., 0] * (new_w / old_w)
    coords[..., 1] = coords[..., 1] * (new_h / old_h)
    return coords
apply_coords_torch is the torch-tensor counterpart of apply_coords: given a coordinate tensor and the original image size, it scales x and y by the ratio between the new and original dimensions.
def apply_boxes_torch(
    self, boxes: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
    boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
    return boxes.reshape(-1, 4)
apply_boxes_torch is the torch-tensor counterpart of apply_boxes: it reshapes the boxes to (-1, 2, 2), rescales the corner points with apply_coords_torch, and reshapes back to (-1, 4).
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
    scale = long_side_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    neww = int(neww + 0.5)
    newh = int(newh + 0.5)
    return (newh, neww)
get_preprocess_shape is a static method that computes the output size for a given input size and target longest-side length: it scales both sides by long_side_length / max(oldh, oldw) and rounds each to the nearest integer. For the demo's 426x640 image with a target of 1024, the scale is 1024 / 640 = 1.6, giving (682, 1024).
In summary, the class resizes images, coordinates, and boxes consistently, which makes it useful for pre- and post-processing in computer-vision pipelines; providing both NumPy and torch variants means the same transform can be applied to either data type.
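To tie the walkthrough together, here is a minimal usage sketch (a standalone script; the 426x640 size and the 1024 target match the demo above, while the sample point is made up):

import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)

# shape computation: the longest side 640 maps to 1024, so scale = 1.6
print(ResizeLongestSide.get_preprocess_shape(426, 640, 1024))  # (682, 1024)

# resize an image and rescale a point with the same transform
image = np.zeros((426, 640, 3), dtype=np.uint8)
resized = transform.apply_image(image)
print(resized.shape)  # (682, 1024, 3)

point = np.array([[340.0, 199.7]])  # (x, y) in the original image
print(transform.apply_coords(point, (426, 640)))  # approximately [[544.0, 319.7]]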