最近有一个想法,想使用Mask DINO这个模型,查了一下,他是基于Detectron2框架实现的,但是又需求对这个框架进行一些魔改,所以需要对这个框架的源码进行学习。
首先是跟着Detectron2的官方文档进行学习:官方文档
直接跳过安装等环节,首先看dataset。
Dataset
我们在torch中使用Dataset的时候,一般这么写:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
# 在这里对数据进行预处理、转换等操作
return sample
# 创建数据集
data = [1, 2, 3, 4, 5]
dataset = CustomDataset(data)
# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
# 使用数据加载器遍历数据
for batch in dataloader:
print(batch)
我们一般在Dataset方法中完成对数据的读取,并且使用__getitem__
方法来根据索引取数据。
然而在Detectron2中,有所不同。
Detectron2中,首先对数据集进行注册,注册后,使用DatasetCatalog.get("my_dataset")
方法来获得数据集字典。注意,这时候仅仅是数据集字典,并没有读取数据集
数据集字典是一个由dict组合的List,List[Dict]
Detectron2 的标准数据集字典:
Task | Fields |
---|---|
Common | file_name, height, width, image_id |
Instance detection/segmentation | annotations |
Semantic segmentation | sem_seg_file_name |
Panoptic segmentation | pan_seg_file_name, segments_info |
具体可见:https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html
每个字典,包含数据集中一张图像的所有信息
Dataloader
在detectron2中,提供了两个函数,build_detection_{train,test}_loader
,然而,针对更高程度的自定义,我首先重写了这两个函数实现了我需要的新功能,并且自定义了mapper方法。
from detectron2.data import detection_utils as utils
# Show how to implement a minimal mapper, similar to the default DatasetMapper
def mapper(dataset_dict):
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
# can use other ways to read image
image = utils.read_image(dataset_dict["file_name"], format="BGR")
# See "Data Augmentation" tutorial for details usage
auginput = T.AugInput(image)
transform = T.Resize((800, 800))(auginput)
image = torch.from_numpy(auginput.image.transpose(2, 0, 1))
annos = [
utils.transform_instance_annotations(annotation, [transform], image.shape[1:])
for annotation in dataset_dict.pop("annotations")
]
return {
# create the format that the model expects
"image": image,
"instances": utils.annotations_to_instances(annos, image.shape[1:])
}
dataloader = build_detection_train_loader(cfg, mapper=mapper)
其中,读取数据并加载到内存中,对注释进行decoder,数据增强方法,在这一步中实现。
Model and Training loop
Detectron2的搭建模型和训练逻辑 关系比较密切。
Model
困难的是,在使用torch编写自己的模型的时候,通常只要继承nn.Module
,然后专注于forward
方法的构建。输入一般为图像的tensor,输出预测结果的tensor。
然而,Detectron2只提供了使用cfg方法搭建模型的接口,并且输入和输出有所不同。有时候我就只想使用我自己写的模型。
Detectron2的model通过outputs = model(inputs)
调用,其中,inputs是list[dict]
。这里,每个dict就是经过mapper后的dataloader的输出。
模型在运行中,分为training和eval模式,其中,training模式的输出是loss构成的字典,而eval模式输出字典组成的列表,其中字典的内容就是模型的预测结果。
详细内容可见文档。
Training loop
训练的循环在 tools文件夹中的train_net.py
文件中有介绍
训练的逻辑通过Trainer类来管理,提供了build_evaluator,build_train_loader,build_lr_scheduler三个方法。分别用来评估模型精度、构建训练逻辑、构建学习率调度器
def main(args):
cfg = setup(args)
if args.eval_only:
model = Trainer.build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
cfg.MODEL.WEIGHTS, resume=args.resume
)
res = Trainer.test(cfg, model)
return res
trainer = Trainer(cfg)
trainer.resume_or_load(resume=args.resume)
return trainer.train()
在main方法中,实现模型的训练或推理过程
def invoke_main() -> None:
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
launch(
main,
args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url=args.dist_url,
args=(args,),
)
通过invok_main方法来实现模型的分布式训练。
然而,这里的Trainer类继承于DefaultTrainer类,我们实际使用的时候,更多的可能继承于SimpleTrainer。说到这,就需要深入研究一下detectron2中engine的部分。
engine由defaults.py
hooks.py
launch.py
train_loop.py
组成。
train_loop.py
train_loop.py
由四个类组成,分别是HookBase
, TrainBase
, SimpleTrainer
, AMPTrainer
.
- HookBase是一个基类,用于实现hook机制。提供了before_train, after_train, before_step, after_backward, after_step, state_dict的接口,但是没有具体实现。
- TrainBase也是一个基类,定义的方法分为3类。(1)注册hook机制。(2)遍历hook list并执行。(3)实现train方法。
def train(self, start_iter: int, max_iter: int):
"""
Args:
start_iter, max_iter (int): See docs above
"""
logger = logging.getLogger(__name__)
logger.info("Starting training from iteration {}".format(start_iter))
self.iter = self.start_iter = start_iter
self.max_iter = max_iter
with EventStorage(start_iter) as self.storage:
try:
self.before_train()
for self.iter in range(start_iter, max_iter):
self.before_step()
self.run_step()
self.after_step()
# self.iter == max_iter can be used by `after_train` to
# tell whether the training successfully finished or failed
# due to exceptions.
self.iter += 1
except Exception:
logger.exception("Exception during training:")
raise
finally:
self.after_train()
从train方法的具体实现过程看,传入的参数分别是start_iter: int, max_iter: int。通过EventStorage(start_iter)
来存储训练中需要记录的信息,通过run_step()
方法实现训练过程。但是在TrainBase中,run_step()
方法没有实现哦。
- SimpleTrainer是TrainBase的子类
它假定在每一步中,需要:(1)使用数据加载器中的数据计算损失。(2)使用上述损失计算梯度。(3)使用优化器更新模型。
训练期间的所有其它任务(检查点、记录日志、评估、学习率调度)由Hook来管理。
如果想要做比这更复杂的任务,可以要么继承 TrainerBase 并实现自己的 run_step,要么编写自己的训练循环。
SimpleTrainer的init方法包括:
def __init__(
self,
model,
data_loader,
optimizer,
gather_metric_period=1,
zero_grad_before_forward=False,
async_write_metrics=False,
):
- AMPTrainer这里我们不做讲解。