Pytorch版deeplabv3+训练自己的数据集

一、数据集准备

1.数据集格式

datasets/
    ├── ImageSets/
    │   └── Segmentation/
    │       ├── train.txt    # 训练集文件名列表
    │       ├── val.txt      # 验证集文件名列表
    │       └── test.txt     # 测试集文件名列表
    ├── masks/               # 存放所有数据集掩码
    │   ├── train/           # 训练集掩码
    │   ├── val/             # 验证集掩码
    │   └── test/            # 测试集掩码
    ├── dataset/             # YOLO格式数据集
        ├── train/
        │   ├── images/      # 存储训练集图像
        │   ├── labels/      # 存储训练集的 YOLO 格式标签
        ├── val/
        │   ├── images/      # 存储验证集图像
        │   ├── labels/      # 存储验证集的 YOLO 格式标签
        └── test/
            ├── images/      # 存储测试集图像
            ├── labels/      # 存储测试集的 YOLO 格式标签

2.数据集目录结构详解

  1. ImageSets/Segmentation/

    • train.txt:包含训练集中所有图像文件的文件名列表(不含扩展名)。这些文件名用于加载训练数据。文件名与 dataset/train/images/ 文件夹中的图像文件相对应。
    • val.txt:包含验证集中所有图像文件的文件名列表(不含扩展名)。这些文件名用于加载验证数据,文件名与 dataset/val/images/ 文件夹中的图像文件相对应。
    • test.txt:包含测试集中所有图像文件的文件名列表(不含扩展名)。这些文件名用于加载测试数据,文件名与 dataset/test/images/ 文件夹中的图像文件相对应。
  2. masks/

    • train/:存放训练集的语义分割掩码图。这些掩码图通常为 .png 格式,是与 dataset/train/images/ 中图像对应的单通道图像。掩码图中的每个像素值代表该像素的类别ID。
    • val/:存放验证集的语义分割掩码图,结构和用途与 train/ 文件夹相同。掩码图的文件名和格式要求与训练集一致,只是数据对应于验证集。
    • test/:存放测试集的语义分割掩码图,结构和用途与 train/ 文件夹相同。这些掩码图用于模型的最终评估。
  3.  dataset/

    • train/images/:存放训练集的原始图像文件。每个图像文件通常为 .jpg.png 格式,用于模型的输入。图像文件名与 masks/train/ 中的掩码图文件名一致(扩展名除外)。
    • train/labels/:存放训练集的 YOLO 格式标签文件。每个标签文件的名称与对应的图像文件名相同,但扩展名为 .txt。标签文件描述了图像中所有对象的类别和边界框信息。标签格式如下:
    • <object-class> <x_center> <y_center> <width> <height>
      

      其中 <object-class> 表示对象的类别ID,<x_center><y_center> 为边界框中心的坐标(归一化到0-1之间),<width><height> 为边界框的宽度和高度(归一化到0-1之间)。

    • val/:结构与 train/ 文件夹相同,包括 images/labels/ 子文件夹。验证集用于在训练过程中评估模型性能。
    • test/:结构与 train/ 文件夹相同,包括 images/labels/ 子文件夹。测试集用于模型的最终性能评估。

3.生成语义分割掩码图的详细步骤(masks)

在目标检测和语义分割任务中,我们通常会使用YOLO格式的标签进行对象检测,但对于语义分割任务,则需要将这些标签转换为掩码图。掩码图中的每个像素表示该像素所属的类别ID,这在训练语义分割模型时非常关键。下面我们详细解释如何将YOLO格式的标签转换为语义分割的掩码图,并提供相应的代码示例。

1. 掩码图生成原理

每个YOLO标签文件包含如下格式的标注:

<class_id> <x_center> <y_center> <width> <height>
  • <class_id>:表示对象的类别ID(例如,0表示行人,1表示汽车)。
  • <x_center> <y_center>:表示对象边界框中心的x、y坐标,这些坐标值是相对于图像宽度和高度进行归一化的。
  • <width><height>:表示对象边界框的宽度和高度,同样是归一化的。

我们的目标是将这些边界框信息转换为对应的掩码图,其中每个像素的值表示该像素所属的类别ID。

2. 实现步骤

  1. 读取YOLO标签文件:我们首先读取YOLO标签文件中的数据,包括对象的类别ID、边界框中心点坐标、宽度和高度。

  2. 转换归一化坐标为像素坐标:将YOLO标签中的归一化坐标(0到1之间)转换为实际图像的像素坐标。

  3. 生成掩码图:使用转换后的像素坐标,在掩码图上绘制矩形,矩形区域的像素值设置为类别ID。

  4. 保存掩码图:将生成的掩码图保存为图像文件,通常使用.png格式。

3. 运行代码生成掩码图

import cv2
import numpy as np
import os

# 类别列表与ID对应关系
class_names = ['', '', '', '', '', '', '']  #填入类别

def yolo_to_mask(label_file, img_file, num_classes=10): #修改类别数量
    """
    将 YOLO 标签转换为语义分割掩码图。

    参数:
    label_file (str): YOLO 标签文件路径。
    img_file (str): 对应图像文件路径。
    num_classes (int): 类别数量。

    返回:
    mask (ndarray): 语义分割掩码图。
    """
    # 读取图像以获取图像尺寸
    img = cv2.imread(img_file)
    if img is None:
        print(f"无法读取图像文件: {img_file}")
        return None
    
    img_shape = img.shape
    mask = np.zeros((img_shape[0], img_shape[1]), dtype=np.uint8)
    
    with open(label_file, 'r') as f:
        for line in f:
            class_id, x_center, y_center, width, height = map(float, line.strip().split())
            x_center *= img_shape[1]  # 将中心点X坐标转换为像素
            y_center *= img_shape[0]  # 将中心点Y坐标转换为像素
            width *= img_shape[1]     # 将宽度转换为像素
            height *= img_shape[0]    # 将高度转换为像素

            # 计算边界框的左上角和右下角
            x1 = int(x_center - width / 2)
            y1 = int(y_center - height / 2)
            x2 = int(x_center + width / 2)
            y2 = int(y_center + height / 2)

            # 在掩码图上绘制矩形,class_id 作为像素值
            cv2.rectangle(mask, (x1, y1), (x2, y2), int(class_id), -1)
    
    return mask

# 示例使用
label_dir = './dataset/test/labels/'  # 标签文件夹路径
image_dir = './dataset/test/images/'  # 对应的图像文件夹路径
output_mask_dir = './masks/test/'  # 保存掩码图的文件夹路径

os.makedirs(output_mask_dir, exist_ok=True)

# 假设标签和图片文件名相同,只是扩展名不同
for label_file in os.listdir(label_dir):
    if label_file.endswith('.txt'):
        img_file = label_file.replace('.txt', '.jpg')  # 替换为图像文件的扩展名
        img_path = os.path.join(image_dir, img_file)
        label_path = os.path.join(label_dir, label_file)
        
        if not os.path.exists(img_path):
            print(f"图像文件不存在: {img_path}")
            continue
        
        mask = yolo_to_mask(label_path, img_path, num_classes=10)
        
        if mask is not None:
            output_mask_path = os.path.join(output_mask_dir, label_file.replace('.txt', '.png'))
            cv2.imwrite(output_mask_path, mask)  # 保存掩码图
            print(f"成功生成掩码图: {output_mask_path}")
        else:
            print(f"掩码图生成失败: {label_path}")

首先,确保您有正确的YOLO标签文件和对应的图像文件。标签文件应与图像文件名匹配,只是扩展名不同(标签文件为 .txt,图像文件为 .jpg.png)。运行上述代码,将生成的掩码图保存在指定的 output_mask_dir 文件夹中。每个生成的掩码图与原始图像尺寸一致,每个像素的值代表该位置的对象类别ID。通过这种方式,您可以将YOLO格式的标签数据转换为语义分割任务所需的掩码图,从而进行进一步的模型训练和验证。

4. 生成 ImageSets/Segmentation/ 目录下的 train.txtval.txttest.txt 文件

在语义分割任务中,ImageSets/Segmentation/ 目录下的 train.txtval.txttest.txt 文件用于列出训练集、验证集和测试集中的图像文件名。这些文件名会在模型训练和验证过程中被用来加载相应的数据。

下面是详细的步骤和代码,展示如何生成这些文件。

步骤概述

  1. 图像文件名获取:从指定的图像目录中获取所有图像文件名,并去掉扩展名(如 .jpg.png),保留文件名部分。

  2. 创建输出目录:如果 ImageSets/Segmentation/ 目录不存在,脚本会自动创建。

  3. 生成 .txt 文件:将处理后的图像文件名写入相应的 .txt 文件中,例如 train.txtval.txttest.txt

  4. 代码实现

    import os
    
    def generate_txt(image_dir, output_path):
        """
        生成图像文件名列表的 .txt 文件。
    
        参数:
        image_dir (str): 图像文件所在的目录。
        output_path (str): 输出 .txt 文件的路径。
        """
        # 获取image_dir目录下所有的图像文件名,去除扩展名后保存到列表中
        images = [f.split('.')[0] for f in os.listdir(image_dir) if f.endswith('.jpg') or f.endswith('.png')]
        
        # 如果输出目录不存在,则创建
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
        # 将图像文件名写入输出文件
        with open(output_path, 'w') as file:
            for img in images:
                file.write(f"{img}\n")
    
        print(f"{output_path} 文件生成完毕,共包含 {len(images)} 条记录。")
    
    if __name__ == "__main__":
        # 设置路径
        train_image_dir = './datasets/dataset/train/images'  # 训练集图像路径
        val_image_dir = './datasets/dataset/val/images'  # 验证集图像路径
        test_image_dir = './datasets/dataset/test/images'  # 测试集图像路径
    
        output_dir = './datasets/ImageSets/Segmentation/'  # 生成 .txt 文件的输出目录
    
        # 生成 train.txt
        generate_txt(train_image_dir, os.path.join(output_dir, 'train.txt'))
        
        # 生成 val.txt
        generate_txt(val_image_dir, os.path.join(output_dir, 'val.txt'))
        
        # 生成 test.txt
        generate_txt(test_image_dir, os.path.join(output_dir, 'test.txt'))
    

    二、deeplabv3+模型配置

1.在mypath.py中添加自己的数据集名称与路径

class Path(object):
    @staticmethod
    def db_root_dir(dataset):
        if dataset == 'pascal':
            return '/path/to/datasets/VOCdevkit/VOC2012/'  # folder that contains VOCdevkit/.
        elif dataset == 'sbd':
            return '/path/to/datasets/benchmark_RELEASE/'  # folder that contains dataset/.
        elif dataset == 'cityscapes':
            return '/path/to/datasets/cityscapes/'     # folder that contains leftImg8bit/
        elif dataset == 'coco':
            return '/path/to/datasets/coco/'
        elif dataset == 'visdrone':
            return '/deep/visdrone-mask/'  # VisDrone dataset directory
        else:
            print('Dataset {} not available.'.format(dataset))
            raise NotImplementedError

1. 扩展 Path 类以支持更多数据集

在深度学习项目中,随着研究的深入可能会使用多个不同的数据集进行实验。为了便于管理这些数据集的路径,可以在 Path 类中轻松扩展支持新的数据集。

假设需要添加一个新的数据集,名为 mynewdataset。以下是具体的代码。

elif dataset == 'mynewdataset':
    return '/path/to/mynewdataset/'

2.在同级目录中修改train.py约185行添加自己数据集的名称(可以设置为默认)

parser.add_argument('--dataset', type=str, default='mynewdataset',
                    choices=['pascal', 'sbd', 'cityscapes', 'coco', 'visdrone', 'mynewdataset'],
                    help='Dataset name (default: mynewdataset)')

通过在 train.py 中添加新的数据集名称,并将其设置为默认值,可以确保在运行训练脚本时,新数据集会被默认加载。

3.在dataloaders目录下修改__init__.py

步骤 1:定位 __init__.py 文件

找到 dataloaders/ 目录下的 __init__.py 文件。这是一个特殊的Python文件,用于将目录标识为Python包。通常,在这个文件中导入数据加载器,并根据需要暴露接口给外部使用。

步骤 2:导入新数据集加载器

__init__.py 文件中,需要导入并注册您的新数据集加载器。假设已经在 dataloaders 目录中创建了一个名为 mynewdataset.py 的文件,且在其中定义了 MyNewDataset 类。

打开 __init__.py 文件,并添加以下内容:

from dataloaders import mynewdataset  # 导入新数据集加载器

步骤 3:更新数据集注册逻辑

为了支持新的数据集,需要在 make_data_loader 函数中添加对新数据集的处理逻辑。

def make_data_loader(args, **kwargs):

    if args.dataset == 'pascal':
        train_set = pascal.VOCSegmentation(args, split='train')
        val_set = pascal.VOCSegmentation(args, split='val')
        if args.use_sbd:
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set])

        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'cityscapes':
        train_set = cityscapes.CityscapesSegmentation(args, split='train')
        val_set = cityscapes.CityscapesSegmentation(args, split='val')
        test_set = cityscapes.CityscapesSegmentation(args, split='test')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'coco':
        train_set = coco.COCOSegmentation(args, split='train')
        val_set = coco.COCOSegmentation(args, split='val')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class
    
    elif args.dataset == 'visdrone':
        train_set = visdrone.VisDroneSegmentation(args, split='train')
        val_set = visdrone.VisDroneSegmentation(args, split='val')
        test_set = visdrone.VisDroneSegmentation(args, split='test')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'mynewdataset':
        train_set = mynewdataset.MyNewDataset(args, split='train')
        val_set = mynewdataset.MyNewDataset(args, split='val')
        test_set = mynewdataset.MyNewDataset(args, split='test')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)

        return train_loader, val_loader, test_loader, num_class

    else:
        raise NotImplementedError

4. 修改dateloaders目录下utils.py

要在 dataloaders 目录下修改 utils.py 文件,假设希望这个文件中的代码支持 mynewdataset 数据集,我们将通过以下几个步骤来进行修改。

步骤 1:导入必要的模块

首先,确保在 utils.py 文件的顶部导入必要的模块。如果还没有导入,请添加以下代码:

import numpy as np

步骤 2:定义 mynewdataset 的颜色映射

首先,您需要定义 mynewdataset 的颜色映射。假设 mynewdataset 有 10 个类别,每个类别对应一个特定的颜色:

def get_mynewdataset_labels():
    """定义 MyNewDataset 数据集的颜色映射"""
    return np.array([
        [128, 0, 0],   # 类别 1
        [0, 128, 0],   # 类别 2
        [128, 128, 0], # 类别 3
        [0, 0, 128],   # 类别 4
        [128, 0, 128], # 类别 5
        [0, 128, 128], # 类别 6
        [128, 128, 128], # 类别 7
        [64, 0, 0],   # 类别 8
        [192, 0, 0],  # 类别 9
        [64, 128, 0], # 类别 10
    ])

步骤 2:修改 decode_segmap 函数

确保 decode_segmap 函数能够处理 mynewdataset,并正确地将分类掩码转换为颜色图:

def decode_segmap(label_mask, dataset, plot=False):
    if dataset == 'pascal' or dataset == 'coco':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'visdrone':
        n_classes = 10
        label_colours = get_visdrone_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    elif dataset == 'mynewdataset':
        n_classes = 10
        label_colours = get_mynewdataset_labels()
    else:
        raise NotImplementedError

    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb

步骤 3:修改 encode_segmap 函数

def encode_segmap(mask, dataset='pascal'):
    """Encode segmentation label images as class indices
    Args:
        mask (np.ndarray): raw segmentation label image of dimension
          (M, N, 3), in which the dataset classes are encoded as colours.
    Returns:
        (np.ndarray): class map with dimensions (M,N), where the value at
        a given location is the integer denoting the class index.
    """
    mask = mask.astype(int)
    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)

    if dataset == 'pascal':
        labels = get_pascal_labels()
    elif dataset == 'cityscapes':
        labels = get_cityscapes_labels()
    elif dataset == 'visdrone':
        labels = get_visdrone_labels()
    elif dataset == 'mynewdataset':
        labels = get_mynewdataset_labels()
    else:
        raise NotImplementedError
    
    for ii, label in enumerate(labels):
        label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
    return label_mask

5.在dataloaders/datasets目录下添加文件

dataloaders/datasets/ 目录下创建一个新的 Python 文件,例如 mynewdataset.py

from __future__ import print_function, division
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr

class MyNewDataset(Dataset):
    """
    MyNewDataset dataset
    """
    NUM_CLASSES = 10  # 根据你的数据集类别数量修改

    def __init__(self,
                 args,
                 base_dir='./mynewdataset',  # 设置你的数据集根目录
                 split='train',
                 ):
        super().__init__()
        # 根据你的数据集结构修改路径
        self._base_dir = base_dir
        self._image_dir = os.path.join(self._base_dir, 'images', split)  # 修改为实际路径
        self._cat_dir = os.path.join(self._base_dir, 'masks', split)  # 修改为实际路径

        if isinstance(split, str):
            self.split = [split]
        else:
            split.sort()
            self.split = split

        self.args = args

        # 确保此路径指向你的 ImageSets 文件夹
        _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation')

        self.im_ids = []
        self.images = []
        self.categories = []

        for splt in self.split:
            with open(os.path.join(_splits_dir, splt + '.txt'), "r") as f:
                lines = f.read().splitlines()

            for ii, line in enumerate(lines):
                _image = os.path.join(self._image_dir, line + ".jpg")  # 修改为实际的图像扩展名
                _cat = os.path.join(self._cat_dir, line + ".png")  # 修改为实际的掩码扩展名
                assert os.path.isfile(_image), f"Image file not found: {_image}"
                assert os.path.isfile(_cat), f"Mask file not found: {_cat}"
                self.im_ids.append(line)
                self.images.append(_image)
                self.categories.append(_cat)

        assert (len(self.images) == len(self.categories))

        print('Number of images in {}: {:d}'.format(split, len(self.images)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}

        for split in self.split:
            if split == "train":
                return self.transform_tr(sample)
            elif split == 'val':
                return self.transform_val(sample)

    def _make_img_gt_point_pair(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])

        return _img, _target

    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def transform_val(self, sample):
        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])

        return composed_transforms(sample)

    def __str__(self):
        return 'MyNewDataset(split=' + str(self.split) + ')'

6. 运行并训练

python train.py --backbone mobilenet --lr 0.007 --workers 1 --epochs 50 --batch-size 8 --gpu-ids 0 --checkname deeplab-mobilenet

–backbone mobilenet 指的是使用mobilenet作为backbone
–gpu-ids 0 指定gpu
–checkname deeplab-mobilenet 使用mobilenet预训练模型

7. 测试

测试testdemo.py
修改–in-path为数据集的测试图片,最后的结果保存在–out-path中

  • --in-path 设置为 ./datasets/dataset/test/images,这是包含原始测试图像的目录。
  • --out-path 设置为 ./datasets/dataset/test/output,这是希望保存生成的语义分割结果图像的目录。
  • run/dataset/deeplab-mobilenet/model_best.pth.tar,这是训练好的权重。
python testdemo.py --ckpt run/visdrone/deeplab-mobilenet/model_best.pth.tar --backbone mobilenet --in-path ./visdrone-mask/VisDrone2019/test/images --out-path ./visdrone-mask/VisDrone2019/test/output

以下是testdemo.py的代码

import argparse
import os
import numpy as np
import time
from modeling.deeplab import *
from dataloaders import custom_transforms as tr
from PIL import Image
from torchvision import transforms
from dataloaders.utils import *
from torchvision.utils import save_image

def main():
    parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Inference")
    parser.add_argument('--in-path', type=str, required=True,
                        help='Path to input images for inference')
    parser.add_argument('--out-path', type=str, required=True, 
                        help='Path to save the output segmentation maps')
    parser.add_argument('--backbone', type=str, default='mobilenet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='Backbone model used in DeepLabV3 (default: mobilenet)')
    parser.add_argument('--ckpt', type=str, default='./run/visdrone/deeplab-mobilenet/model_best.pth.tar',
                        help='Path to the saved model checkpoint')
    parser.add_argument('--out-stride', type=int, default=16,
                        help='Network output stride (default: 16)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='If set, disables CUDA training')
    parser.add_argument('--gpu-ids', type=str, default='0',
                        help='Comma-separated list of GPU IDs to use (default: 0)')
    parser.add_argument('--dataset', type=str, default='visdrone',
                        choices=['pascal', 'coco', 'cityscapes', 'visdrone'],
                        help='Dataset name (default: visdrone)')
    parser.add_argument('--crop-size', type=int, default=513,
                        help='Crop size for inference (default: 513)')
    parser.add_argument('--num_classes', type=int, default=10,
                        help='Number of classes (default: 10 for VisDrone)')
    parser.add_argument('--sync-bn', type=bool, default=None,
                        help='Whether to use synchronized batch normalization')
    parser.add_argument('--freeze-bn', type=bool, default=False,
                        help='If set, freezes batch normalization parameters')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError:
            raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')

    if args.sync_bn is None:
        args.sync_bn = args.cuda and len(args.gpu_ids) > 1

    # Load the model
    model = DeepLab(num_classes=args.num_classes,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn)

    ckpt = torch.load(args.ckpt, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    model = model.cuda() if args.cuda else model
    model.eval()

    # Updated transformation pipeline for inference (only apply to images)
    composed_transforms = transforms.Compose([
        transforms.Resize((args.crop_size, args.crop_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])

    if not os.path.exists(args.out_path):
        os.makedirs(args.out_path)

    # Inference
    for name in os.listdir(args.in_path):
        if name.startswith("."):  # Skip hidden files like .ipynb_checkpoints
            continue

        image_path = os.path.join(args.in_path, name)
        image = Image.open(image_path).convert('RGB')
        tensor_in = composed_transforms(image).unsqueeze(0)

        if args.cuda:
            tensor_in = tensor_in.cuda()

        with torch.no_grad():
            output = model(tensor_in)

        seg_map = torch.max(output, 1)[1].detach().cpu().numpy()
        seg_map = decode_segmap(seg_map[0], dataset=args.dataset)
        output_image_path = os.path.join(args.out_path, name.replace(".jpg", "_mask.png"))
        save_image(torch.tensor(seg_map).permute(2, 0, 1), output_image_path)

        print(f"Processed {name}, saved segmentation map to {output_image_path}")

if __name__ == "__main__":
    main()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值