最近 Transformer 在 CV 圈可谓 hot 到爆,记录一下之前参加比赛所用的代码 Swin Transformer,算法来源于 ICCV 2021 的 best paper。废话少说,直接上源码 code: https://github.com/SwinTransformer/Swin-Transformer-Object-Detection paper: https://openaccess.thecvf.com/content/ICCV2021/papers/Liu_Swin_Transformer_Hierarchical_Vision_Transformer_Using_Shifted_Windows_ICCV_2021_paper.pdf
搭建swin-transform需要的环境
1. conda create -n swin python=3.7
2. conda activate swin
3. conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=11.0 -c pytorch
4. pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html
5. pip install mmdet
本教程所用的mmdetection版本:v2.18.0,code链接 https://github.com/open-mmlab/mmdetection/releases/tag/v2.18.0 1. 在mmdetection目录下执行setup.py编译环境(例如 `python setup.py develop`,或等价的 `pip install -v -e .`),至此环境的事基本上弄完了。2. 由于我们的目标只有一类,需要修改mmdetection下的相关文件;另外新建一个data文件夹存放数据集,数据集按COCO数据集格式准备。
找到configs文件夹下的_base_文件夹,打开其中datasets文件夹里的coco_detection.py(即 configs/_base_/datasets/coco_detection.py),对应code如下:
# Training pipeline: a list of transform configs, one dict per step.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # Only bounding boxes are requested from the annotations here.
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    # img_norm_cfg is defined outside this snippet and expanded as kwargs.
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
# Test pipeline: load the image, then hand it to a single MultiScaleFlipAug
# step configured with one scale and flip=False.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        # Nested transform configs passed to MultiScaleFlipAug.
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            # img_norm_cfg is defined outside this snippet.
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ],
    ),
]
找到configs文件夹下的_base_文件夹,打开其中的models文件夹,如图所示:本文选用mask_rcnn_r50_fpn.py作为基础模型配置。修改mask_rcnn_r50_fpn.py里的num_classes:bbox_head和mask_head中各有一处,都要由80改为自己数据集的类别数(本文只有一类,即改为1)。
对应code
model = dict (
type = 'MaskRCNN' ,
backbone= dict (
type = 'ResNet' ,
depth= 50 ,
num_stages= 4 ,
out_indices= ( 0 , 1 , 2 , 3 ) ,
frozen_stages= 1 ,
norm_cfg= dict ( type = 'BN' , requires_grad= True ) ,
norm_eval= True ,
style= 'pytorch' ,
init_cfg= dict ( type = 'Pretrained' , checkpoint= 'torchvision://resnet50' ) ) ,
neck= dict (
type = 'FPN' ,
in_channels= [ 256 , 512 , 1024 , 2048 ] ,
out_channels= 256 ,
num_outs= 5 ) ,
rpn_head= dict (
type = 'RPNHead' ,
in_channels= 256 ,
feat_channels= 256 ,
anchor_generator= dict (
type = 'AnchorGenerator' ,
scales= [ 8 ] ,
ratios= [ 0.5 , 1.0 , 2.0 ] ,
strides= [ 4 , 8 , 16 , 32 , 64 ] ) ,
bbox_coder= dict (
type = 'DeltaXYWHBBoxCoder' ,
target_means= [ .0 , .0 , .0 , .0 ] ,
target_stds= [ 1.0 , 1.0 , 1.0 , 1.0 ] ) ,
loss_cls= dict (
type = 'CrossEntropyLoss' , use_sigmoid= True , loss_weight= 1.0 ) ,
loss_bbox= dict ( type = 'L1Loss' , loss_weight= 1.0 ) ) ,
roi_head= dict (
type = 'StandardRoIHead' ,
bbox_roi_extractor= dict (
type = 'SingleRoIExtractor' ,
roi_layer= dict ( type = 'RoIAlign' , output_size= 7 , sampling_ratio= 0 ) ,
out_channels= 256 ,
featmap_strides= [ 4 , 8 , 16 , 32 ] ) ,
bbox_head= dict (
type = 'Shared2FCBBoxHead' ,
in_channels= 256 ,
fc_out_channels= 1024 ,
roi_feat_size= 7 ,
num_classes= 80 ,
bbox_coder= dict (
type = 'DeltaXYWHBBoxCoder' ,
target_means= [ 0. , 0. , 0. , 0. ] ,
target_stds= [ 0.1 , 0.1 , 0.2 , 0.2 ] ) ,
reg_class_agnostic= False ,
loss_cls= dict (
type = 'CrossEntropyLoss' , use_sigmoid= False , loss_weight= 1.0 ) ,
loss_bbox= dict ( type = 'L1Loss' , loss_weight= 1.0 ) ) ,
mask_roi_extractor= dict (
type = 'SingleRoIExtractor' ,
roi_layer= dict ( type = 'RoIAlign' , output_size= 14 , sampling_ratio= 0 ) ,
out_channels= 256 ,
featmap_strides= [ 4 , 8 , 16 , 32 ] ) ,
mask_head= dict (
type = 'FCNMaskHead' ,
num_convs= 4 ,
in_channels= 256 ,
conv_out_channels= 256 ,
num_classes= 80 ,
loss_mask= dict (
type = 'CrossEntropyLoss' , use_mask= True , loss_weight= 1.0 ) ) ) ,
修改mmdet
找到mmdet下core里的evaluation文件夹(mmdet/core/evaluation),打开其中的class_names.py,把coco_classes()返回的类别名改成自己数据集的label。下面先给出原始代码,再给出修改后的写法:
def coco_classes():
    """Return the default COCO category names.

    Returns:
        list[str]: 80 names; multi-word names are joined with underscores
        (e.g. ``'traffic_light'``). A new list is built on every call.
    """
    return [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
        'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
        'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
        'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
        'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
        'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
        'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
        'scissors', 'teddy_bear', 'hair_drier', 'toothbrush',
    ]
def coco_classes():
    """Return the category names of the custom dataset.

    This is the edited version of ``coco_classes``: the string below is a
    placeholder to be replaced with the dataset's own label(s).

    Returns:
        list[str]: one entry per class.
    """
    return ['你自己的label']
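找到mmdet下datasets文件夹里的coco.py(mmdet/datasets/coco.py),如图。把其中的CLASSES同样改成自己数据集的label;只有一类时要写成 CLASSES = ('你自己的label',),括号里的逗号不能省,否则它只是一个字符串而不是元组。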
class CocoDataset(CustomDataset):
    """Dataset class whose category names are declared in ``CLASSES``.

    Only the ``CLASSES`` attribute is shown in this excerpt; the base class
    ``CustomDataset`` is defined elsewhere.
    """

    # Category names as a tuple of 80 strings; multi-word names use spaces
    # here (e.g. 'traffic light').
    CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
               'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
               'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
               'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
               'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
               'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
               'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
               'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
               'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
开始训练
python tools/train.py configs/swin/mask_rcnn_swin-t-p4-w7_fpn_fp16_ms-crop-3x_coco.py
跑起来了。开始训练后,训练日志默认保存在mmdetection下的work_dirs目录里。
调用训练好的模型测试
import asyncio
from argparse import ArgumentParser
from mmdet. apis import ( async_inference_detector, inference_detector,
init_detector, show_result_pyplot)
def parse_args():
    """Parse the command-line arguments of the inference script.

    Returns:
        argparse.Namespace: carries ``img``, ``config``, ``checkpoint``,
        ``device``, ``score_thr`` and ``async_test``.
    """
    parser = ArgumentParser()
    # NOTE: these three are required positionals, so argparse never falls
    # back to ``default``; the values below are placeholders only.
    parser.add_argument('img', default='image root', help='Image file')
    parser.add_argument('config', default='config root', help='Config file')
    parser.add_argument(
        'checkpoint', default='save model root', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    parser.add_argument(
        '--async-test',
        action='store_true',
        help='whether to set async options for async inference.')
    return parser.parse_args()
def main(args):
    """Run synchronous inference on one image and show the result.

    Args:
        args: Parsed arguments; ``config``, ``checkpoint``, ``device``,
            ``img`` and ``score_thr`` are read.
    """
    # Build the detector from the config file and checkpoint.
    model = init_detector(args.config, args.checkpoint, device=args.device)
    result = inference_detector(model, args.img)
    # Hand the result to the plotting helper with the score threshold.
    show_result_pyplot(model, args.img, result, score_thr=args.score_thr)
async def async_main(args):
    """Run inference through the async API on one image and show the result.

    Args:
        args: Parsed arguments; ``config``, ``checkpoint``, ``device``,
            ``img`` and ``score_thr`` are read.
    """
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # Schedule the inference coroutine as a task and wait for it; gather
    # returns a list with one entry per awaited task.
    tasks = asyncio.create_task(async_inference_detector(model, args.img))
    result = await asyncio.gather(tasks)
    show_result_pyplot(model, args.img, result[0], score_thr=args.score_thr)
# Script entry point: choose the async or the synchronous path according to
# the --async-test flag.
if __name__ == '__main__':
    args = parse_args()
    if args.async_test:
        asyncio.run(async_main(args))
    else:
        main(args)