首先看一下系统架构:
解读create_model方法
1.backbone = MobileNetV2(weights_path="./backbone/mobilenet_v2.pth").features-》加载MobileNetV2预训练权重,并只取其特征提取部分(.features)作为backbone
backbone.out_channels = 1280-》设置backbone输出特征图的通道数(MobileNetV2的features部分输出通道数为1280)
2.anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
aspect_ratios=((0.5, 1.0, 2.0),))-》创建AnchorsGenerator对象(anchor生成器):在一个特征层上使用5种尺度和3种宽高比,即每个位置生成15个anchor;AnchorsGenerator的具体作用以后再讲。
3.roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], # 在哪些特征层上进行roi 池化
output_size=[7, 7], # roi_pooling输出特征矩阵尺寸
sampling_ratio=2) # 采样率
4. model = FasterRCNN(backbone=backbone,
num_classes=num_classes,
rpn_anchor_generator=anchor_generator,
box_roi_pool=roi_pooler)-》用上面的backbone、anchor生成器和roi池化层,构建faster_rcnn_framework中的FasterRCNN模型
解读main方法:
1.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")-》选择训练设备:有可用GPU时使用cuda:0,否则使用CPU
print(device)
2. if not os.path.exists("save_weights"):
os.makedirs("save_weights")-》 检查保存权重文件夹是否存在,不存在则创建
3. data_transform = {
"train": transforms.Compose([transforms.ToTensor(),
transforms.RandomHorizontalFlip(0.5)]),
"val": transforms.Compose([transforms.ToTensor()])
}-》定义数据预处理:训练集转为Tensor并以0.5的概率随机水平翻转,验证集只转为Tensor
4. VOC_root = "./"
assert os.path.exists(os.path.join(VOC_root, "VOCdevkit")), "not found VOCdevkit in path:'{}'".format(VOC_root)-》设置数据集根目录,并检查该目录下是否存在VOCdevkit文件夹(不存在则报错)
5.train_data_set = VOC2012DataSet(VOC_root, data_transform["train"], True)-》加载训练数据
train_data_loader = torch.utils.data.DataLoader(train_data_set,
batch_size=8,
shuffle=False,
num_workers=0,
collate_f