A few days ago my boss wanted to try FSOD (few-shot object detection). I had attempted to reproduce it from the paper alone a while back, but my reproduction was disappointing and I gave up. He recently discovered it is not only open source now but has gone through several versions, so he handed it to me to try. The repo is: https://github.com/fanq15/FewX
A quick look shows, though, that this code only has training and evaluation; there is no inference/test script. It is built on the detectron2 framework, which honestly is a lot like the mmdetection I wrote about in an earlier post: convenient to use, painful and annoying to modify. Back when I wrote that post I felt I knew mmdetection inside out and had a lot I wanted to write up, but I was lazy, and by now I've forgotten most of it. So this post is partly a way to record things again while writing this code.
The environment setup is similar to mmdetection. In my experience the simplest approach is to read the provided Dockerfile and infer what it needs; it's mostly copy and paste, and arguably even simpler than mmdetection. One caveat: when I installed it, the two frameworks needed different mmcv versions, which probably cannot coexist. Either use Docker, or do what I did and just overwrite the old one, and deal with mmdet breakage later if it comes up. Honestly, I'd rather not use these frameworks at all; they are a real headache.
First, how to read the FewX code (I'll keep this brief, so no screenshots). The README says to pass a config path, e.g. the yaml file "finetune_R_50_C4_1x.yaml". That file shows the basic structure: it inherits from another yaml via the "_BASE_" key, which you can then go read as well. One thing worth noting: some keys carry literal values such as numbers and booleans, while others carry the name of a class that will be instantiated, as in this yaml fragment:
MODEL:
  META_ARCHITECTURE: "FsodRCNN"
  PROPOSAL_GENERATOR:
    NAME: "FsodRPN"
  RPN:
    PRE_NMS_TOPK_TEST: 6000
    POST_NMS_TOPK_TEST: 100
  ROI_HEADS:
    NAME: "FsodRes5ROIHeads"
    BATCH_SIZE_PER_IMAGE: 128
    POSITIVE_FRACTION: 0.5
    NUM_CLASSES: 1
  BACKBONE:
    FREEZE_AT: 3
  #PIXEL_MEAN: [102.9801, 115.9465, 122.7717]
Here "FsodRCNN" and "FsodRPN" are class names. So where do these classes live? Maybe the key names hint at their locations, but my approach is to look through the __init__.py files in the various packages. For example, "fewx/modeling/__init__.py" quickly turns up the following:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .fsod import FsodRCNN, FsodRes5ROIHeads, FsodFastRCNNOutputLayers, FsodRPN
_EXCLUDE = {"torch", "ShapeSpec"}
__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")]
The first import line lists exactly the classes we need. At this point you basically understand how detectron2 wires things up: it resolves classes by string, through these config files and the various __init__.py files. Once you see that, there isn't much mystery left; if you need custom behavior, take a yaml, modify it, and have it call your own classes.
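The string-to-class dispatch can be sketched in a few lines. This is only a toy illustration of the registry pattern, not detectron2's actual implementation:

```python
# Toy sketch of the registry pattern detectron2 uses to resolve class
# names from yaml strings (names here are illustrative, not detectron2's
# real code).
class Registry:
    def __init__(self):
        self._map = {}

    def register(self):
        def deco(cls):
            self._map[cls.__name__] = cls  # key = the class name as a string
            return cls
        return deco

    def get(self, name):
        return self._map[name]

META_ARCH = Registry()

@META_ARCH.register()
class FsodRCNN:  # stand-in for the real meta-architecture class
    pass

# The framework reads "FsodRCNN" from the yaml and looks it up:
cls = META_ARCH.get("FsodRCNN")
print(cls.__name__)  # FsodRCNN
```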
By the way, his fsod_rcnn.py is written quite clearly and is worth a read if you need it. I ran the evaluation part; since I only have one GPU, I changed all.sh to:
python3 fsod_train_net.py --num-gpus 1 \
--config-file configs/fsod/finetune_R_50_C4_1x.yaml \
--eval-only MODEL.WEIGHTS ./output/fsod/finetune_dir/R_50_C4_1x/model_final.pth 2>&1 | tee log/fsod_finetune_test_log.txt
Running just this snippet is enough. If you want to trace the code path, follow it from here and add prints step by step. Evaluation generates a series of intermediate files, and fsod_rcnn.py contains the code that reads them back, so a careful read makes it clear what each file stores and how to load it. The author uses pandas; I just read them back directly with pickle. Also, the dataset goes through some preprocessing, which I won't go into here.
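For reference, pandas' to_pickle writes a standard pickle stream, which is why plain pickle can read those intermediate files back (as long as the pickled object itself doesn't require pandas to unpickle, or pandas is installed). A minimal sketch, with an in-memory buffer standing in for one of those files and a made-up dict standing in for its contents:

```python
import io
import pickle

# A dict stands in for whatever the intermediate file actually stores;
# the buffer stands in for a real file opened with open(path, "rb").
data = {"cls_id": 0, "res5_avg": [0.1, 0.2, 0.3]}
buf = io.BytesIO()
pickle.dump(data, buf)     # roughly what pandas' to_pickle does underneath
buf.seek(0)

loaded = pickle.load(buf)  # read it back without importing pandas
print(loaded["cls_id"])    # 0
```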
Since all I need to write is a prediction script, I don't have to load ground truth, so the COCO APIs are mostly unnecessary. I've always wanted to break away from the detectron2 framework and load the model myself, so that deployment or pure inference wouldn't require installing this cumbersome framework: load the model, read an image, get results, done. I haven't actually done that, though, so the code below still uses the framework.
To avoid modifying the author's source, I first copied his FsodRCNN class and renamed it FsodWordRCNN. Note that the @META_ARCH_REGISTRY.register() decorator must stay, or the class won't be found. This also suggests the __init__.py entry may not be strictly necessary: the framework can only locate the class after this decorator registers it (and if you keep the original class name, registration fails with a duplicate-name error). Then change the class name in the yaml, and running the same command now enters my custom class. I only need to implement the prediction part: __init__ is copied verbatim, and forward simply calls the inference function, like this:
@META_ARCH_REGISTRY.register()
class FsodWordRCNN(nn.Module):
    """
    Generalized R-CNN. Any models that contains the following three components:
    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction
    """

    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape())
        # self.proposal_generator.training = False
        self.roi_heads = build_roi_heads(cfg, self.backbone.output_shape())
        self.vis_period = cfg.VIS_PERIOD
        self.input_format = cfg.INPUT.FORMAT
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1))
        self.in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
        self.support_way = cfg.INPUT.FS.SUPPORT_WAY
        self.support_shot = cfg.INPUT.FS.SUPPORT_SHOT
        # self.logger = logging.getLogger(__name__)

    def forward(self, batched_inputs):
        # self.init_model()
        return self.inference(batched_inputs)
I haven't checked whether everything in __init__ is actually used. At this point you can test whether the inference function really gets called, though I'm not including that test code here (I no longer have it).
The init code below loads the model; my detectron2 version is 0.2.
def init(configfile="configs/finetune_R_50_C4_1x.yaml",
         modelfile="./output/fsod/finetune_dir/R_50_C4_1x/model_final.pth",
         classnum=4):
    cfg = get_cfg()
    cfg.merge_from_file(configfile)
    cfg.MODEL.WEIGHTS = modelfile
    predictor = DefaultPredictor(cfg)
    DetectionCheckpointer(predictor.model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=False
    )
    return predictor.model
This loading method differs a bit from the original code; the author's would also work with small changes, I just prefer loading this way. For the exact mechanics, or the differences between the two, read the detectron2 source. Note that what's returned is predictor.model, not the predictor itself: calling the predictor as if it were the model would go through detectron2's own code path, which is not the execution flow I want.
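To illustrate why the bare model is returned, here is a toy sketch (not detectron2 code): DefaultPredictor's __call__ applies detectron2's own preprocessing and input packaging before forwarding, whereas the bare model's forward accepts whatever we define, which here will be the custom [query, size, supports, boxes] list:

```python
class ToyModel:
    """Stand-in for FsodWordRCNN: forward takes a free-form input list."""
    def __call__(self, batched_inputs):
        return {"inputs_seen": batched_inputs}

class ToyPredictor:
    """Stand-in for DefaultPredictor: fixed single-image input packaging."""
    def __init__(self, model):
        self.model = model
    def __call__(self, image):
        # detectron2-style dict packaging we do NOT want for few-shot inputs
        return self.model([{"image": image, "height": 600, "width": 800}])

model = ToyModel()
predictor = ToyPredictor(model)

# Calling the predictor forces its packaging; calling model directly does not.
print(predictor("img")["inputs_seen"][0]["height"])            # 600
print(model(["query", (600, 800), [], []])["inputs_seen"][1])  # (600, 800)
```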
Here is the image-loading code:
def loadData():
    from glob import glob
    from torchvision.transforms import ToTensor
    import json

    img_root = './data_gen'
    with open(f"{img_root}/support.json", 'r') as ff:
        support_img_datas = json.load(ff)
    o_query_img = utils.read_image(f"{img_root}/img_000000011.jpg", format="BGR")
    h, w = o_query_img.shape[:2]
    aug_input = T.StandardAugInput(o_query_img, sem_seg=None)
    transforms = aug_input.apply_augmentations([ResizeShortestEdge(short_edge_length=(600, 600), max_size=1000, sample_style='choice')])
    query_img = aug_input.image
    query_img = torch.as_tensor(np.ascontiguousarray(query_img.transpose(2, 0, 1)))
    support_imgs = []
    support_boxes = []
    for item in support_img_datas:
        for k, support_img_data in item.items():
            imgPath, imgBox = support_img_data
            img = utils.read_image(f"{img_root}/{imgPath}", format='BGR')
            img = torch.as_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
            support_imgs.append(img)
            support_boxes.append(Boxes([imgBox]))
    return o_query_img, query_img, (h, w), support_imgs, support_boxes
The paths in this code are specific to my setup. One thing to note: I originally wanted to read the images with cv2. That ran to completion but produced no detections at all; switching to the author's loader (utils.read_image) fixed it, even though it also appears to read in BGR. I was too lazy to dig into why. The principle of FSOD is: given a large query image and n support images of the same class, it finds targets of that class in the query image, so the input is one query image plus n supports. Since I draw results later, I also keep and return the original image as read. The most painful part was that the supports undergo some preprocessing; it took a long trawl through the source to find it, which was exhausting.
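For concreteness, here is a hypothetical support.json reverse-engineered from how loadData() iterates it: a list of one-entry dicts, each value a (relative image path, [x1, y1, x2, y2] box) pair. The keys, file names, and coordinates below are made up for illustration:

```python
import io
import json

# Made-up support.json content matching loadData()'s unpacking pattern.
support_img_datas = [
    {"shot_0": ["support_000.jpg", [12, 30, 96, 150]]},
    {"shot_1": ["support_001.jpg", [5, 8, 80, 120]]},
]
buf = io.StringIO()            # stands in for the file on disk
json.dump(support_img_datas, buf)
buf.seek(0)
loaded = json.load(buf)

# Same iteration as loadData(): path + box per support shot.
for item in loaded:
    for k, (imgPath, imgBox) in item.items():
        print(k, imgPath, imgBox)
```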
Here is the inference code:
def inference(self, batched_inputs, detected_instances=None, do_postprocess=True):
    print("inference----begin!")
    query_img, q_size, support_imgs, support_boxes = batched_inputs
    support_imgs = [(x - self.pixel_mean) / self.pixel_std for x in support_imgs]
    support_imgs = ImageList.from_tensors(support_imgs, self.backbone.size_divisibility)
    support_features = self.backbone(support_imgs.tensor)
    # print(support_features.keys())
    print(support_features['res4'].shape)
    res4_pooled = self.roi_heads.roi_pooling(support_features, support_boxes)
    res4_avg = res4_pooled.mean(0, True)
    res4_avg = res4_avg.mean(dim=[2, 3], keepdim=True)
    # support_dict['res4_avg'][cls] = res4_avg.detach().cpu().data
    print(res4_avg.shape)
    res5_feature = self.roi_heads._shared_roi_transform([support_features[f] for f in self.in_features], support_boxes)
    res5_avg = res5_feature.mean(0, True)
    # support_dict['res5_avg'][cls] = res5_avg.detach().cpu().data
    print(res5_avg.shape)

    query_img = (query_img - self.pixel_mean) / self.pixel_std
    images = ImageList.from_tensors([query_img], self.backbone.size_divisibility)
    print(images.tensor.shape)
    features = self.backbone(images.tensor)

    support_proposals_dict = {}
    support_box_features_dict = {}
    # for cls_id in range(20):
    cls_id = 0
    query_images = ImageList.from_tensors([images[0]])
    query_features_res4 = features['res4']  # one query feature for attention rpn
    query_features = {'res4': query_features_res4}
    support_box_features = res5_avg  # .detach().cpu().data
    correlation = F.conv2d(query_features_res4, res4_avg.permute(1, 0, 2, 3), groups=1024)
    support_correlation = {'res4': correlation}  # attention map for attention rpn
    proposals, _ = self.proposal_generator(query_images, support_correlation, None)
    print('proposals----------------->', len(proposals[0]))
    support_proposals_dict[cls_id] = proposals
    support_box_features_dict[cls_id] = support_box_features
    results, _ = self.roi_heads.eval_with_support(query_images, query_features, support_proposals_dict, support_box_features_dict)
    # print(results)
    return FsodWordRCNN._postprocess(results, [{'height': q_size[0], 'width': q_size[1]}], images.image_sizes)

@staticmethod
def _postprocess(instances, batched_inputs, image_sizes):
    """
    Rescale the output instances to the target size.
    """
    # note: private function; subject to changes
    processed_results = []
    for results_per_image, input_per_image, image_size in zip(
        instances, batched_inputs, image_sizes
    ):
        height = input_per_image.get("height", image_size[0])
        width = input_per_image.get("width", image_size[1])
        r = detector_postprocess(results_per_image, height, width)
        processed_results.append({"instances": r})
    return processed_results
This mostly follows the author's code, minus one of his loops. One modification to his code is required: in "fewx/modeling/fsod/fsod_roi_heads.py" there is a line "assert len(proposal_boxes[0]) == 2000"; comment it out or it will fail, because in my case the proposal count never reaches 2000. That value was probably set with COCO's per-class setup in mind, whereas here I have a single image and a single class to predict, so there are simply fewer proposals.
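One detail worth unpacking is the F.conv2d(query_features_res4, res4_avg.permute(1, 0, 2, 3), groups=1024) call. Since res4_avg has already been averaged down to a 1x1 kernel per channel, a grouped conv with groups equal to the channel count reduces to scaling each query channel by the matching support channel's mean, i.e. a cheap per-channel correlation. A numpy sketch of the same arithmetic, with a toy channel count instead of 1024:

```python
import numpy as np

C, H, W = 4, 3, 3                         # toy sizes; the real code uses C=1024
query = np.random.rand(1, C, H, W)        # stands in for the query res4 feature map
support_avg = np.random.rand(1, C, 1, 1)  # stands in for res4_avg after the spatial mean

# A depthwise 1x1 conv (groups == C) is just per-channel elementwise scaling.
correlation = query * support_avg
print(correlation.shape)                  # (1, 4, 3, 3)
```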
At this point prediction works. The complete code is as follows:
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F

from fewx.config import get_cfg
from fewx.modeling.fsod.fsod_roi_heads import build_roi_heads
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.modeling.proposal_generator import build_proposal_generator
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.structures import ImageList, Boxes, Instances
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.data.transforms.augmentation_impl import ResizeShortestEdge
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import DefaultPredictor  # , DefaultTrainer

__all__ = ["FsodWordRCNN"]
@META_ARCH_REGISTRY.register()
class FsodWordRCNN(nn.Module):
    """
    Generalized R-CNN. Any models that contains the following three components:
    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction
    """

    def __init__(self, cfg):
        super().__init__()
        self.backbone = build_backbone(cfg)
        self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape())
        self.roi_heads = build_roi_heads(cfg, self.backbone.output_shape())
        self.vis_period = cfg.VIS_PERIOD
        self.input_format = cfg.INPUT.FORMAT
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1))
        self.in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
        self.support_way = cfg.INPUT.FS.SUPPORT_WAY
        self.support_shot = cfg.INPUT.FS.SUPPORT_SHOT

    def forward(self, batched_inputs):
        return self.inference(batched_inputs)

    def inference(self, batched_inputs, detected_instances=None, do_postprocess=True):
        print("inference----begin!")
        query_img, q_size, support_imgs, support_boxes = batched_inputs
        support_imgs = [(x - self.pixel_mean) / self.pixel_std for x in support_imgs]
        support_imgs = ImageList.from_tensors(support_imgs, self.backbone.size_divisibility)
        support_features = self.backbone(support_imgs.tensor)
        res4_pooled = self.roi_heads.roi_pooling(support_features, support_boxes)
        res4_avg = res4_pooled.mean(0, True)
        res4_avg = res4_avg.mean(dim=[2, 3], keepdim=True)
        res5_feature = self.roi_heads._shared_roi_transform([support_features[f] for f in self.in_features], support_boxes)
        res5_avg = res5_feature.mean(0, True)

        query_img = (query_img - self.pixel_mean) / self.pixel_std
        images = ImageList.from_tensors([query_img], self.backbone.size_divisibility)
        features = self.backbone(images.tensor)

        support_proposals_dict = {}
        support_box_features_dict = {}
        cls_id = 0
        query_images = ImageList.from_tensors([images[0]])
        query_features_res4 = features['res4']  # one query feature for attention rpn
        query_features = {'res4': query_features_res4}
        support_box_features = res5_avg  # .detach().cpu().data
        correlation = F.conv2d(query_features_res4, res4_avg.permute(1, 0, 2, 3), groups=1024)
        support_correlation = {'res4': correlation}  # attention map for attention rpn
        proposals, _ = self.proposal_generator(query_images, support_correlation, None)
        support_proposals_dict[cls_id] = proposals
        support_box_features_dict[cls_id] = support_box_features
        results, _ = self.roi_heads.eval_with_support(query_images, query_features, support_proposals_dict, support_box_features_dict)
        return FsodWordRCNN._postprocess(results, [{'height': q_size[0], 'width': q_size[1]}], images.image_sizes)

    @staticmethod
    def _postprocess(instances, batched_inputs, image_sizes):
        """
        Rescale the output instances to the target size.
        """
        # note: private function; subject to changes
        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            instances, batched_inputs, image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            processed_results.append({"instances": r})
        return processed_results
def init(configfile="configs/finetune_R_50_C4_1x.yaml",
         modelfile="./output/fsod/finetune_dir/R_50_C4_1x/model_final.pth",
         classnum=4):
    cfg = get_cfg()
    cfg.merge_from_file(configfile)
    cfg.MODEL.WEIGHTS = modelfile
    predictor = DefaultPredictor(cfg)
    DetectionCheckpointer(predictor.model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=False
    )
    return predictor.model
def loadData():
    from glob import glob
    from torchvision.transforms import ToTensor
    import json

    img_root = './data_gen'
    with open(f"{img_root}/support.json", 'r') as ff:
        support_img_datas = json.load(ff)
    o_query_img = utils.read_image(f"{img_root}/img_000000011.jpg", format="BGR")
    h, w = o_query_img.shape[:2]
    aug_input = T.StandardAugInput(o_query_img, sem_seg=None)
    transforms = aug_input.apply_augmentations([ResizeShortestEdge(short_edge_length=(600, 600), max_size=1000, sample_style='choice')])
    query_img = aug_input.image
    query_img = torch.as_tensor(np.ascontiguousarray(query_img.transpose(2, 0, 1)))
    support_imgs = []
    support_boxes = []
    for item in support_img_datas:
        for k, support_img_data in item.items():
            imgPath, imgBox = support_img_data
            img = utils.read_image(f"{img_root}/{imgPath}", format='BGR')
            img = torch.as_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
            support_imgs.append(img)
            support_boxes.append(Boxes([imgBox]))
    return o_query_img, query_img, (h, w), support_imgs, support_boxes
if __name__ == "__main__":
    import cv2
    import numpy as np

    device = torch.device('cuda')
    model = init()
    o_query_img, query_img, (h, w), support_imgs, support_boxes = loadData()
    query_img = query_img.to(device)
    support_imgs = [i.to(device) for i in support_imgs]
    support_boxes = [i.to(device) for i in support_boxes]
    res = model([query_img, (h, w), support_imgs, support_boxes])
    print(o_query_img.shape, type(o_query_img))
    o_query_img = o_query_img.astype(np.uint8)
    for item in res:
        instances = item['instances']
        fields = instances.get_fields()
        pred_boxes = fields['pred_boxes']
        for bbox in pred_boxes:
            # pred_boxes are in XYXY format: (x1, y1) top-left, (x2, y2) bottom-right
            x1, y1, x2, y2 = bbox.cpu().numpy().astype(int).tolist()
            cv2.rectangle(o_query_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
    cv2.imwrite("test.jpg", o_query_img)
I may have deleted a few other bits along with the comments, so bear with it. Discussion welcome.