DeepLabV3+ Reproduction

Notes from a DeepLabV3+ reproduction experiment

1. Dataset

1.1 Downloading and organizing the data

Reference: https://blog.csdn.net/github_36923418/article/details/81453519
The DeepLabV3+ paper states that the experiments use the VOC2012 dataset, augmented with the extra annotations from SBD.

VOC2012 covers 20 object classes (grouped into 4 broad categories), or 21 classes counting the background. The directory structure is:

  • Annotations: XML files describing every object in each image
  • ImageSets: lists of annotated images
  • JPEGImages: all original images, 17125 in total
  • SegmentationClass: segmentation masks labeled by class, 2913 in total
  • SegmentationObject: same as above, but different instances of the same class get different colors

For segmentation, the PNG masks in SegmentationClass are normally used as labels, with the corresponding originals as input images.

VOC2012 download link:
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

Because the VOC samples are limited, experiments usually run on SBD, the VOC-based augmented dataset built by the paper "Semantic Contours from Inverse Detectors":
http://home.bharathh.info/pubs/codes/SBD/download.html

The downloaded count does not match the 10582 images mentioned in the paper. It turns out that 10582 = (VOC train/val ∪ SBD train/val) − (VOC val). A simpler route is this file:
https://gist.githubusercontent.com/sun11/2dbda6b31acc7c6292d14a872d0c90b7/raw/5f5a5270089239ef2f6b65b1cc55208355b5acca/trainaug.txt
which lists the names of the 10582 images; just pick the matching images out of VOC2012.
The labels can be downloaded from (requires a proxy; mirrors exist on GitHub): https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0
The masks inside are grayscale, so they are hard to tell apart by eye.

The validation and test splits come straight from VOC2012; just copy them over.
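The trainaug arithmetic described above can be sketched with plain set operations. The image IDs below are hypothetical tiny sets, not the real splits:

```python
# Hypothetical tiny image-ID sets illustrating how the 10582-image trainaug
# split is derived: (VOC train/val ∪ SBD train/val) minus VOC val
voc_trainval = {"2007_000032", "2007_000039", "2007_000063"}
sbd_trainval = {"2007_000039", "2007_000063", "2008_000002", "2008_000003"}
voc_val = {"2007_000063"}

trainaug = (voc_trainval | sbd_trainval) - voc_val
print(sorted(trainaug))
```

On the real data, the same union-minus-val computation yields the 10582 names in trainaug.txt.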

After downloading, look up each name from the txt file and copy the image and its label into the target folders:

import os
import shutil

# Input/output paths for the images and labels; adjust to your own layout
txt_path = r"txt"
load_lab_path = r'SegmentationClassAug'
load_img_path = r"JPEGImages"
save_img_path = r"image"
save_lab_path = r"label"
with open(txt_path, "r") as f:
    for line in f:
        pic_name = line.rstrip("\n")
        shutil.copyfile(os.path.join(load_img_path, pic_name + '.jpg'), os.path.join(save_img_path, pic_name + '.jpg'))
        shutil.copyfile(os.path.join(load_lab_path, pic_name + '.png'), os.path.join(save_lab_path, pic_name + '.png'))

1.2 Building the dataset

import glob
import os
import random

import cv2
import numpy as np
from torch.utils.data import Dataset


class COV2012_LOADER(Dataset):
    def __init__(self, data_path):
        self.data_path = data_path
        self.img_path = glob.glob(os.path.join(data_path, 'image/*.jpg'))

    # __getitem__ and __len__ must be overridden when subclassing Dataset
    def __getitem__(self, index):
        image_path = self.img_path[index]
        label_path = image_path.replace('image', 'label').replace('jpg', 'png')
        # note: skimage.io.imread returns RGB, while cv2.imread returns BGR
        image = cv2.imread(image_path)
        label = cv2.imread(label_path)
        # map the white (255) ignore pixels to background (0)
        label[np.all(label == 255, axis=-1)] = 0
        image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_AREA)
        # nearest-neighbor for the label, so class indices are not interpolated
        label = cv2.resize(label, (512, 512), interpolation=cv2.INTER_NEAREST)
        # random horizontal flip, applied while still in HWC layout
        if random.choice([0, 1]) == 0:
            image = self.augment(image)
            label = self.augment(label)
        # HWC -> CHW; reshape would scramble the pixels, transpose is required
        image = image.transpose(2, 0, 1)
        label = label.transpose(2, 0, 1)
        return image, label[0]

    def __len__(self):
        return len(self.img_path)

    # data augmentation: horizontal flip
    def augment(self, image):
        return cv2.flip(image, 1)
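One detail worth checking in __getitem__ is the HWC→CHW layout change: `reshape(3, H, W)` keeps the flat byte order and scrambles the pixels, so `transpose` must be used. A quick check on a tiny dummy array:

```python
import numpy as np

# cv2.imread returns an H x W x C array; PyTorch expects C x H x W.
img = np.arange(2 * 2 * 3, dtype=np.uint8).reshape(2, 2, 3)  # tiny 2x2 "image"
chw = img.transpose(2, 0, 1)  # correct: axes are reordered, pixels intact

print(chw.shape)                       # (3, 2, 2)
print(chw[0, 1, 1] == img[1, 1, 0])    # each channel plane matches the source
print(np.array_equal(img.reshape(3, 2, 2), chw))  # reshape gives a different array
```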

1.3 Processing the label data

SBD and VOC use different label encodings, so they need to be unified.
After looking up the color-to-class mapping, it is just a matter of replacing values accordingly; note that cv2.imread returns BGR images, not RGB.

import glob
import os

import cv2
import numpy as np


def cov_to_sbd(read_path, save_path, trans_list):
    for image_name in glob.glob(os.path.join(read_path, '*.png')):
        name = os.path.basename(image_name)
        image = cv2.imread(image_name)
        for i, row in enumerate(image):
            for j, line in enumerate(row):
                bgr = line.tolist()
                # [192, 224, 224] is the BGR value of the VOC border color (224, 224, 192)
                if (bgr != [0, 0, 0]) and (bgr != [192, 224, 224]):
                    rgb = bgr[::-1]
                    if rgb in trans_list:
                        num = 1 + trans_list.index(rgb)
                        image[i][j] = np.array([num, num, num])
                else:
                    image[i][j] = np.array([0, 0, 0])
        cv2.imwrite(os.path.join(save_path, name), image)
    print('finish')


if __name__ == '__main__':
    cov_path = r''
    sbd_path = r''
    trans_list = [[128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
                  [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
                  [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
                  [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
                  [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
    cov_to_sbd(read_path=cov_path, save_path=sbd_path, trans_list=trans_list)
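The per-pixel double loop above is correct but very slow on 500×375 masks. The same color→index mapping can be done with vectorized NumPy masks; this is a sketch, with the palette in the same RGB order as trans_list (class index = position + 1):

```python
import numpy as np

# VOC class colors in RGB, same order as trans_list above
PALETTE_RGB = [[128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
               [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
               [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
               [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
               [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]


def colors_to_index(label_bgr):
    """Map a BGR color mask (as returned by cv2.imread) to a class-index map."""
    rgb = label_bgr[..., ::-1]                        # BGR -> RGB view
    out = np.zeros(label_bgr.shape[:2], dtype=np.uint8)
    for idx, color in enumerate(PALETTE_RGB, start=1):
        out[np.all(rgb == color, axis=-1)] = idx      # one boolean mask per color
    return out                                        # unknown colors stay 0
```

This replaces the inner pixel loop with 20 whole-image comparisons, which is orders of magnitude faster in practice.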

1.4 Training

Training uses DDP (DistributedDataParallel).
Launch command:

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
import argparse

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim


def train_net(net, data_path, epochs=10, batch_size=2, lr=0.00001):
    # distributed training setup (torch.distributed.launch passes --local_rank)
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', default=-1, type=int,
                        help='node rank for distributed training')
    args = parser.parse_args()
    gpus = args.local_rank
    torch.cuda.set_device(gpus)
    dist.init_process_group(backend='nccl')
    net.cuda(gpus)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.local_rank], find_unused_parameters=True)
    # load the data
    isbi_dataset = COV2012_LOADER(data_path + 'train')
    train_sampler = torch.utils.data.distributed.DistributedSampler(isbi_dataset)
    train_loader = torch.utils.data.DataLoader(isbi_dataset, batch_size=batch_size, sampler=train_sampler, drop_last=True)
    val_dataset = COV2012_LOADER(data_path + 'val')
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, drop_last=True)
    # RMSprop optimizer
    optimizer = optim.RMSprop(net.parameters(), lr=lr, momentum=0.9)
    # poly learning-rate decay
    lambda1 = lambda epo: (1 - epo / epochs) ** 0.9
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
    # loss function
    criterion = nn.CrossEntropyLoss()
    # track the best loss, initialized to +inf
    best_loss = float('inf')
    # record loss and lr per epoch
    t_Loss, v_Loss, lr_list = [], [], []
    for epoch in range(epochs):
        try:
            # train
            train_sampler.set_epoch(epoch)  # reshuffle across processes each epoch
            net.train()
            loss_sum = 0
            for batch_index, data in enumerate(train_loader):
                image, label = data
                # move the batch to the GPU
                image = image.cuda(gpus, non_blocking=True)
                label = label.cuda(gpus, non_blocking=True)
                # zero the gradients
                optimizer.zero_grad()
                # forward pass
                pred = net(image.float())
                # compute the loss
                loss = criterion(pred, label.long())
                loss_sum += loss.item()
                # backward pass and parameter update
                loss.backward()
                optimizer.step()
                del image, label, pred, loss
                print('epoch:', epoch, ' index:', batch_index)
            # validation: eval mode plus no_grad, so no gradients are stored
            net.eval()
            val_loss = 0
            with torch.no_grad():
                for image, label in val_loader:
                    image = image.cuda(gpus, non_blocking=True)
                    label = label.cuda(gpus, non_blocking=True)
                    pred = net(image.float())
                    loss = criterion(pred, label.long())
                    val_loss += loss.item()
                    del image, label, pred, loss
            # true division; // would truncate the float losses to integers
            avg_loss = loss_sum / len(train_loader)
            avg_val_loss = val_loss / len(val_loader)
            t_Loss.append(avg_loss)
            v_Loss.append(avg_val_loss)
            lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
            print('epoch %d, train_loss:%f, val_loss:%f, LR:%f' %
                  (epoch, avg_loss, avg_val_loss, optimizer.state_dict()['param_groups'][0]['lr']))
            # update the learning rate
            scheduler.step()
            # save the weights with the lowest training loss
            if avg_loss < best_loss:
                best_loss = avg_loss
                torch.save(net.state_dict(), '/sunny/deeplabv3+/best_model.pth')
        except RuntimeError as exception:
            if "out of memory" in str(exception):
                print("WARNING: out of memory")
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                raise exception
    print(t_Loss, v_Loss, lr_list)


if __name__ == "__main__":
    net = DeepLabV3Plus(
        n_classes=21,
        n_blocks=[3, 4, 23, 3],
        atrous_rates=[6, 12, 18],
        multi_grids=[1, 2, 4],
        output_stride=16
    )
    data_path = r"data/"
    train_net(net, data_path)
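The poly decay used above can be sanity-checked standalone: LambdaLR multiplies the base learning rate by the lambda's value each epoch, so lr(epoch) = base_lr · (1 − epoch/epochs)^0.9, which decays monotonically to 0 at the final epoch:

```python
# Standalone check of the poly learning-rate schedule
epochs = 10
base_lr = 0.00001
poly = lambda epo: (1 - epo / epochs) ** 0.9
lrs = [base_lr * poly(e) for e in range(epochs)]

print(lrs[0])   # epoch 0: the base learning rate
print(lrs[5])   # epoch 5: base_lr * 0.5 ** 0.9
```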

2. Training

Training ran into CUDA out of memory errors. The main remedies, summarized:

  • Reduce batch_size (try this last)
  • Use net.eval() together with torch.no_grad() on the validation set so no gradients are tracked
  • Use nn.ReLU(inplace=True) to cut memory use; it overwrites its input instead of allocating an intermediate tensor
  • Clear the cache when memory runs out (del together with empty_cache)
try:
    net.train()
    ...
    del image, label, pred, loss 
except RuntimeError as exception:
    if "out of memory" in str(exception):
        print("WARNING: out of memory")
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()
    else:
        raise exception
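The inplace idea from the list above can be illustrated outside PyTorch: np.maximum with out= writes the ReLU result back into the input buffer, analogous to what nn.ReLU(inplace=True) does with activations (a NumPy analogy, not the PyTorch implementation):

```python
import numpy as np

x = np.array([-1.0, 2.0, -3.0, 4.0])
buf_before = x.__array_interface__['data'][0]   # address of the data buffer
np.maximum(x, 0.0, out=x)                       # ReLU applied in place, no new array
buf_after = x.__array_interface__['data'][0]

print(x)                        # negatives clamped to 0
print(buf_before == buf_after)  # same buffer reused
```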

3. To be added later
