DETR的代码非常简洁,GitHub仓库,下面对DETR的每一部分进行详细解读
这一篇博文主要介绍参数、分布式初始化、随机种子和模型部分
参数部分
DETR的一开始是参数部分,这里只需要掌握argparse库即可,不做过多介绍
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
分布式初始化
这里我做了微小的修改,但是整体含义不变
分布式初始化的方法很简单,一般来说我们单机多卡进行训练,进入到第一个if语句,从环境变量中获取RANK、WORLD_SIZE和LOCAL_RANK这三个变量,再单机多卡中他们分别代表第几张卡,一共有几张卡,第几张卡,即RANK和LOCAL_RANK是一样的含义
随后设置当前分布式初始化的是第几张卡torch.cuda.set_device(args.gpu),并设置通信后端,默认的采取nccl进行通讯,然后就是init_process_group(backend, init_method, world_size, rank)
进行初始化了,
最后通过barrier函数来同步初始化,barrier简单理解为当所有卡都运行到这一步时,再继续运行下面的步骤,如果有卡没运行到这一步,则其余的卡会等待他,最后setup_for_distributed函数设置输出的卡,这里的意思是只有0卡的print函数会进行输出
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank==0) # 只有0卡进行输出
随机种子
从可复现的角度来考虑,所有的卡都要设置随机种子,并且对代码中可能出现的所有随机性函数赋值种子,这里考虑了torch,numpy以及random库中可能存在的随机初始化过程,来固定种子
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
模型初始化
这里我们主要讲解目标检测模型的初始化过程,DETR还可以进行实例分割与全景分割等任务,但是我们这里只讲解目标检测模型,后面会补充分割模型的设置
模型初始化函数为build_model
model, criterion, postprocessors = build_model(args)
对于coco检测任务,一共有91个类别,但是在实际模型中我们会看到92个类别,这是因为目标检测框如果内部无物体的话属于空,即多了背景类
DETR模型主要分为CNN的backbone和transformer,在模型中也分别初始化了这两点。下面对这两部分分别进行详细的解读
backbone = build_backbone(args)
transformer = build_transformer(args)
Backbone部分
backbone部分都是为transformer服务的,在backbone部分提供了位置编码和resnet特征提取器,位置编码采用的是attention is all you need中的常用位置编码,更详细的细节参考这篇文章
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)