一、数据集准备
1.数据集格式
datasets/
├── ImageSets/
│ └── Segmentation/
│ ├── train.txt # 训练集文件名列表
│ ├── val.txt # 验证集文件名列表
│ └── test.txt # 测试集文件名列表
├── masks/ # 存放所有数据集掩码
│ ├── train/ # 训练集掩码
│ ├── val/ # 验证集掩码
│ └── test/ # 测试集掩码
├── dataset/ # YOLO格式数据集
├── train/
│ ├── images/ # 存储训练集图像
│ ├── labels/ # 存储训练集的 YOLO 格式标签
├── val/
│ ├── images/ # 存储验证集图像
│ ├── labels/ # 存储验证集的 YOLO 格式标签
└── test/
├── images/ # 存储测试集图像
├── labels/ # 存储测试集的 YOLO 格式标签
2.数据集目录结构详解
-
ImageSets/Segmentation/:
train.txt
:包含训练集中所有图像文件的文件名列表(不含扩展名)。这些文件名用于加载训练数据。文件名与dataset/train/images/
文件夹中的图像文件相对应。val.txt
:包含验证集中所有图像文件的文件名列表(不含扩展名)。这些文件名用于加载验证数据,文件名与dataset/val/images/
文件夹中的图像文件相对应。test.txt
:包含测试集中所有图像文件的文件名列表(不含扩展名)。这些文件名用于加载测试数据,文件名与dataset/test/images/
文件夹中的图像文件相对应。
-
masks/:
train/
:存放训练集的语义分割掩码图。这些掩码图通常为.png
格式,是与dataset/train/images/
中图像对应的单通道图像。掩码图中的每个像素值代表该像素的类别ID。val/
:存放验证集的语义分割掩码图,结构和用途与train/
文件夹相同。掩码图的文件名和格式要求与训练集一致,只是数据对应于验证集。test/
:存放测试集的语义分割掩码图,结构和用途与train/
文件夹相同。这些掩码图用于模型的最终评估。
-
dataset/:
train/
:images/
:存放训练集的原始图像文件。每个图像文件通常为.jpg
或.png
格式,用于模型的输入。图像文件名与masks/train/
中的掩码图文件名一致(扩展名除外)。train/
:labels/
:存放训练集的 YOLO 格式标签文件。每个标签文件的名称与对应的图像文件名相同,但扩展名为.txt
。标签文件描述了图像中所有对象的类别和边界框信息。标签格式如下:-
<object-class> <x_center> <y_center> <width> <height>
其中
<object-class>
表示对象的类别ID,<x_center>
和<y_center>
为边界框中心的坐标(归一化到0-1之间),<width>
和<height>
为边界框的宽度和高度(归一化到0-1之间)。 val/
:结构与train/
文件夹相同,包括images/
和labels/
子文件夹。验证集用于在训练过程中评估模型性能。test/
:结构与train/
文件夹相同,包括images/
和labels/
子文件夹。测试集用于模型的最终性能评估。
3.生成语义分割掩码图的详细步骤(masks)
在目标检测和语义分割任务中,我们通常会使用YOLO格式的标签进行对象检测,但对于语义分割任务,则需要将这些标签转换为掩码图。掩码图中的每个像素表示该像素所属的类别ID,这在训练语义分割模型时非常关键。下面我们详细解释如何将YOLO格式的标签转换为语义分割的掩码图,并提供相应的代码示例。
1. 掩码图生成原理
每个YOLO标签文件包含如下格式的标注:
<class_id> <x_center> <y_center> <width> <height>
<class_id>
:表示对象的类别ID(例如,0表示行人,1表示汽车)。<x_center>
和<y_center>
:表示对象边界框中心的x、y坐标,这些坐标值是相对于图像宽度和高度进行归一化的。<width>
和<height>
:表示对象边界框的宽度和高度,同样是归一化的。
我们的目标是将这些边界框信息转换为对应的掩码图,其中每个像素的值表示该像素所属的类别ID。
2. 实现步骤
-
读取YOLO标签文件:我们首先读取YOLO标签文件中的数据,包括对象的类别ID、边界框中心点坐标、宽度和高度。
-
转换归一化坐标为像素坐标:将YOLO标签中的归一化坐标(0到1之间)转换为实际图像的像素坐标。
-
生成掩码图:使用转换后的像素坐标,在掩码图上绘制矩形,矩形区域的像素值设置为类别ID。
-
保存掩码图:将生成的掩码图保存为图像文件,通常使用
.png
格式。
3. 运行代码生成掩码图
import cv2
import numpy as np
import os
# 类别列表与ID对应关系
class_names = ['', '', '', '', '', '', ''] #填入类别
def yolo_to_mask(label_file, img_file, num_classes=10): #修改类别数量
"""
将 YOLO 标签转换为语义分割掩码图。
参数:
label_file (str): YOLO 标签文件路径。
img_file (str): 对应图像文件路径。
num_classes (int): 类别数量。
返回:
mask (ndarray): 语义分割掩码图。
"""
# 读取图像以获取图像尺寸
img = cv2.imread(img_file)
if img is None:
print(f"无法读取图像文件: {img_file}")
return None
img_shape = img.shape
mask = np.zeros((img_shape[0], img_shape[1]), dtype=np.uint8)
with open(label_file, 'r') as f:
for line in f:
class_id, x_center, y_center, width, height = map(float, line.strip().split())
x_center *= img_shape[1] # 将中心点X坐标转换为像素
y_center *= img_shape[0] # 将中心点Y坐标转换为像素
width *= img_shape[1] # 将宽度转换为像素
height *= img_shape[0] # 将高度转换为像素
# 计算边界框的左上角和右下角
x1 = int(x_center - width / 2)
y1 = int(y_center - height / 2)
x2 = int(x_center + width / 2)
y2 = int(y_center + height / 2)
# 在掩码图上绘制矩形,class_id 作为像素值
cv2.rectangle(mask, (x1, y1), (x2, y2), int(class_id), -1)
return mask
# 示例使用
label_dir = './dataset/test/labels/' # 标签文件夹路径
image_dir = './dataset/test/images/' # 对应的图像文件夹路径
output_mask_dir = './masks/test/' # 保存掩码图的文件夹路径
os.makedirs(output_mask_dir, exist_ok=True)
# 假设标签和图片文件名相同,只是扩展名不同
for label_file in os.listdir(label_dir):
if label_file.endswith('.txt'):
img_file = label_file.replace('.txt', '.jpg') # 替换为图像文件的扩展名
img_path = os.path.join(image_dir, img_file)
label_path = os.path.join(label_dir, label_file)
if not os.path.exists(img_path):
print(f"图像文件不存在: {img_path}")
continue
mask = yolo_to_mask(label_path, img_path, num_classes=10)
if mask is not None:
output_mask_path = os.path.join(output_mask_dir, label_file.replace('.txt', '.png'))
cv2.imwrite(output_mask_path, mask) # 保存掩码图
print(f"成功生成掩码图: {output_mask_path}")
else:
print(f"掩码图生成失败: {label_path}")
首先,确保您有正确的YOLO标签文件和对应的图像文件。标签文件应与图像文件名匹配,只是扩展名不同(标签文件为 .txt
,图像文件为 .jpg
或 .png
)。运行上述代码,将生成的掩码图保存在指定的 output_mask_dir
文件夹中。每个生成的掩码图与原始图像尺寸一致,每个像素的值代表该位置的对象类别ID。通过这种方式,您可以将YOLO格式的标签数据转换为语义分割任务所需的掩码图,从而进行进一步的模型训练和验证。
4. 生成 ImageSets/Segmentation/
目录下的 train.txt
、val.txt
和 test.txt
文件
在语义分割任务中,ImageSets/Segmentation/
目录下的 train.txt
、val.txt
和 test.txt
文件用于列出训练集、验证集和测试集中的图像文件名。这些文件名会在模型训练和验证过程中被用来加载相应的数据。
下面是详细的步骤和代码,展示如何生成这些文件。
步骤概述
-
图像文件名获取:从指定的图像目录中获取所有图像文件名,并去掉扩展名(如
.jpg
或.png
),保留文件名部分。 -
创建输出目录:如果
ImageSets/Segmentation/
目录不存在,脚本会自动创建。 -
生成
.txt
文件:将处理后的图像文件名写入相应的.txt
文件中,例如train.txt
、val.txt
和test.txt
。 -
代码实现
import os def generate_txt(image_dir, output_path): """ 生成图像文件名列表的 .txt 文件。 参数: image_dir (str): 图像文件所在的目录。 output_path (str): 输出 .txt 文件的路径。 """ # 获取image_dir目录下所有的图像文件名,去除扩展名后保存到列表中 images = [f.split('.')[0] for f in os.listdir(image_dir) if f.endswith('.jpg') or f.endswith('.png')] # 如果输出目录不存在,则创建 os.makedirs(os.path.dirname(output_path), exist_ok=True) # 将图像文件名写入输出文件 with open(output_path, 'w') as file: for img in images: file.write(f"{img}\n") print(f"{output_path} 文件生成完毕,共包含 {len(images)} 条记录。") if __name__ == "__main__": # 设置路径 train_image_dir = './datasets/dataset/train/images' # 训练集图像路径 val_image_dir = './datasets/dataset/val/images' # 验证集图像路径 test_image_dir = './datasets/dataset/test/images' # 测试集图像路径 output_dir = './datasets/ImageSets/Segmentation/' # 生成 .txt 文件的输出目录 # 生成 train.txt generate_txt(train_image_dir, os.path.join(output_dir, 'train.txt')) # 生成 val.txt generate_txt(val_image_dir, os.path.join(output_dir, 'val.txt')) # 生成 test.txt generate_txt(test_image_dir, os.path.join(output_dir, 'test.txt'))
二、deeplabv3+模型配置
1.在mypath.py
中添加自己的数据集名称与路径
class Path(object):
@staticmethod
def db_root_dir(dataset):
if dataset == 'pascal':
return '/path/to/datasets/VOCdevkit/VOC2012/' # folder that contains VOCdevkit/.
elif dataset == 'sbd':
return '/path/to/datasets/benchmark_RELEASE/' # folder that contains dataset/.
elif dataset == 'cityscapes':
return '/path/to/datasets/cityscapes/' # folder that contains leftImg8bit/
elif dataset == 'coco':
return '/path/to/datasets/coco/'
elif dataset == 'visdrone':
return '/deep/visdrone-mask/' # VisDrone dataset directory
else:
print('Dataset {} not available.'.format(dataset))
raise NotImplementedError
1. 扩展 Path 类以支持更多数据集
在深度学习项目中,随着研究的深入可能会使用多个不同的数据集进行实验。为了便于管理这些数据集的路径,可以在 Path
类中轻松扩展支持新的数据集。
假设需要添加一个新的数据集,名为 mynewdataset
。以下是具体的代码。
elif dataset == 'mynewdataset':
return '/path/to/mynewdataset/'
2.在同级目录中修改train.py
约185行添加自己数据集的名称(可以设置为默认)
parser.add_argument('--dataset', type=str, default='mynewdataset',
choices=['pascal', 'sbd', 'cityscapes', 'coco', 'visdrone', 'mynewdataset'],
help='Dataset name (default: mynewdataset)')
通过在 train.py
中添加新的数据集名称,并将其设置为默认值,可以确保在运行训练脚本时,新数据集会被默认加载。
3.在dataloaders目录下修改__init__.py
步骤 1:定位 __init__.py
文件
找到 dataloaders/
目录下的 __init__.py
文件。这是一个特殊的Python文件,用于将目录标识为Python包。通常,在这个文件中导入数据加载器,并根据需要暴露接口给外部使用。
步骤 2:导入新数据集加载器
在 __init__.py
文件中,需要导入并注册您的新数据集加载器。假设已经在 dataloaders
目录中创建了一个名为 mynewdataset.py
的文件,且在其中定义了 MyNewDataset
类。
打开 __init__.py
文件,并添加以下内容:
from dataloaders import mynewdataset # 导入新数据集加载器
步骤 3:更新数据集注册逻辑
为了支持新的数据集,需要在 make_data_loader
函数中添加对新数据集的处理逻辑。
def make_data_loader(args, **kwargs):
if args.dataset == 'pascal':
train_set = pascal.VOCSegmentation(args, split='train')
val_set = pascal.VOCSegmentation(args, split='val')
if args.use_sbd:
sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set])
num_class = train_set.NUM_CLASSES
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
test_loader = None
return train_loader, val_loader, test_loader, num_class
elif args.dataset == 'cityscapes':
train_set = cityscapes.CityscapesSegmentation(args, split='train')
val_set = cityscapes.CityscapesSegmentation(args, split='val')
test_set = cityscapes.CityscapesSegmentation(args, split='test')
num_class = train_set.NUM_CLASSES
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)
return train_loader, val_loader, test_loader, num_class
elif args.dataset == 'coco':
train_set = coco.COCOSegmentation(args, split='train')
val_set = coco.COCOSegmentation(args, split='val')
num_class = train_set.NUM_CLASSES
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
test_loader = None
return train_loader, val_loader, test_loader, num_class
elif args.dataset == 'visdrone':
train_set = visdrone.VisDroneSegmentation(args, split='train')
val_set = visdrone.VisDroneSegmentation(args, split='val')
test_set = visdrone.VisDroneSegmentation(args, split='test')
num_class = train_set.NUM_CLASSES
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)
return train_loader, val_loader, test_loader, num_class
elif args.dataset == 'mynewdataset':
train_set = mynewdataset.MyNewDataset(args, split='train')
val_set = mynewdataset.MyNewDataset(args, split='val')
test_set = mynewdataset.MyNewDataset(args, split='test')
num_class = train_set.NUM_CLASSES
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)
return train_loader, val_loader, test_loader, num_class
else:
raise NotImplementedError
4. 修改dateloaders目录下utils.py
要在 dataloaders
目录下修改 utils.py
文件,假设希望这个文件中的代码支持 mynewdataset
数据集,我们将通过以下几个步骤来进行修改。
步骤 1:导入必要的模块
首先,确保在 utils.py
文件的顶部导入必要的模块。如果还没有导入,请添加以下代码:
import numpy as np
步骤 2:定义 mynewdataset
的颜色映射
首先,您需要定义 mynewdataset
的颜色映射。假设 mynewdataset
有 10 个类别,每个类别对应一个特定的颜色:
def get_mynewdataset_labels():
"""定义 MyNewDataset 数据集的颜色映射"""
return np.array([
[128, 0, 0], # 类别 1
[0, 128, 0], # 类别 2
[128, 128, 0], # 类别 3
[0, 0, 128], # 类别 4
[128, 0, 128], # 类别 5
[0, 128, 128], # 类别 6
[128, 128, 128], # 类别 7
[64, 0, 0], # 类别 8
[192, 0, 0], # 类别 9
[64, 128, 0], # 类别 10
])
步骤 2:修改 decode_segmap
函数
确保 decode_segmap
函数能够处理 mynewdataset
,并正确地将分类掩码转换为颜色图:
def decode_segmap(label_mask, dataset, plot=False):
if dataset == 'pascal' or dataset == 'coco':
n_classes = 21
label_colours = get_pascal_labels()
elif dataset == 'visdrone':
n_classes = 10
label_colours = get_visdrone_labels()
elif dataset == 'cityscapes':
n_classes = 19
label_colours = get_cityscapes_labels()
elif dataset == 'mynewdataset':
n_classes = 10
label_colours = get_mynewdataset_labels()
else:
raise NotImplementedError
r = label_mask.copy()
g = label_mask.copy()
b = label_mask.copy()
for ll in range(0, n_classes):
r[label_mask == ll] = label_colours[ll, 0]
g[label_mask == ll] = label_colours[ll, 1]
b[label_mask == ll] = label_colours[ll, 2]
rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
rgb[:, :, 0] = r / 255.0
rgb[:, :, 1] = g / 255.0
rgb[:, :, 2] = b / 255.0
if plot:
plt.imshow(rgb)
plt.show()
else:
return rgb
步骤 3:修改 encode_segmap
函数
def encode_segmap(mask, dataset='pascal'):
"""Encode segmentation label images as class indices
Args:
mask (np.ndarray): raw segmentation label image of dimension
(M, N, 3), in which the dataset classes are encoded as colours.
Returns:
(np.ndarray): class map with dimensions (M,N), where the value at
a given location is the integer denoting the class index.
"""
mask = mask.astype(int)
label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
if dataset == 'pascal':
labels = get_pascal_labels()
elif dataset == 'cityscapes':
labels = get_cityscapes_labels()
elif dataset == 'visdrone':
labels = get_visdrone_labels()
elif dataset == 'mynewdataset':
labels = get_mynewdataset_labels()
else:
raise NotImplementedError
for ii, label in enumerate(labels):
label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
return label_mask
5.在dataloaders/datasets目录下添加文件
在 dataloaders/datasets/
目录下创建一个新的 Python 文件,例如 mynewdataset.py
。
from __future__ import print_function, division
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr
class MyNewDataset(Dataset):
"""
MyNewDataset dataset
"""
NUM_CLASSES = 10 # 根据你的数据集类别数量修改
def __init__(self,
args,
base_dir='./mynewdataset', # 设置你的数据集根目录
split='train',
):
super().__init__()
# 根据你的数据集结构修改路径
self._base_dir = base_dir
self._image_dir = os.path.join(self._base_dir, 'images', split) # 修改为实际路径
self._cat_dir = os.path.join(self._base_dir, 'masks', split) # 修改为实际路径
if isinstance(split, str):
self.split = [split]
else:
split.sort()
self.split = split
self.args = args
# 确保此路径指向你的 ImageSets 文件夹
_splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation')
self.im_ids = []
self.images = []
self.categories = []
for splt in self.split:
with open(os.path.join(_splits_dir, splt + '.txt'), "r") as f:
lines = f.read().splitlines()
for ii, line in enumerate(lines):
_image = os.path.join(self._image_dir, line + ".jpg") # 修改为实际的图像扩展名
_cat = os.path.join(self._cat_dir, line + ".png") # 修改为实际的掩码扩展名
assert os.path.isfile(_image), f"Image file not found: {_image}"
assert os.path.isfile(_cat), f"Mask file not found: {_cat}"
self.im_ids.append(line)
self.images.append(_image)
self.categories.append(_cat)
assert (len(self.images) == len(self.categories))
print('Number of images in {}: {:d}'.format(split, len(self.images)))
def __len__(self):
return len(self.images)
def __getitem__(self, index):
_img, _target = self._make_img_gt_point_pair(index)
sample = {'image': _img, 'label': _target}
for split in self.split:
if split == "train":
return self.transform_tr(sample)
elif split == 'val':
return self.transform_val(sample)
def _make_img_gt_point_pair(self, index):
_img = Image.open(self.images[index]).convert('RGB')
_target = Image.open(self.categories[index])
return _img, _target
def transform_tr(self, sample):
composed_transforms = transforms.Compose([
tr.RandomHorizontalFlip(),
tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
tr.RandomGaussianBlur(),
tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
tr.ToTensor()])
return composed_transforms(sample)
def transform_val(self, sample):
composed_transforms = transforms.Compose([
tr.FixScaleCrop(crop_size=self.args.crop_size),
tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
tr.ToTensor()])
return composed_transforms(sample)
def __str__(self):
return 'MyNewDataset(split=' + str(self.split) + ')'
6. 运行并训练
python train.py --backbone mobilenet --lr 0.007 --workers 1 --epochs 50 --batch-size 8 --gpu-ids 0 --checkname deeplab-mobilenet
–backbone mobilenet 指的是使用mobilenet作为backbone
–gpu-ids 0 指定gpu
–checkname deeplab-mobilenet 使用mobilenet预训练模型
7. 测试
测试testdemo.py
修改–in-path为数据集的测试图片,最后的结果保存在–out-path中
--in-path
设置为./datasets/dataset/test/images
,这是包含原始测试图像的目录。--out-path
设置为./datasets/dataset/test/output
,这是希望保存生成的语义分割结果图像的目录。- run/dataset/deeplab-mobilenet/model_best.pth.tar,这是训练好的权重。
python testdemo.py --ckpt run/visdrone/deeplab-mobilenet/model_best.pth.tar --backbone mobilenet --in-path ./visdrone-mask/VisDrone2019/test/images --out-path ./visdrone-mask/VisDrone2019/test/output
以下是testdemo.py的代码
import argparse
import os
import numpy as np
import time
from modeling.deeplab import *
from dataloaders import custom_transforms as tr
from PIL import Image
from torchvision import transforms
from dataloaders.utils import *
from torchvision.utils import save_image
def main():
parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Inference")
parser.add_argument('--in-path', type=str, required=True,
help='Path to input images for inference')
parser.add_argument('--out-path', type=str, required=True,
help='Path to save the output segmentation maps')
parser.add_argument('--backbone', type=str, default='mobilenet',
choices=['resnet', 'xception', 'drn', 'mobilenet'],
help='Backbone model used in DeepLabV3 (default: mobilenet)')
parser.add_argument('--ckpt', type=str, default='./run/visdrone/deeplab-mobilenet/model_best.pth.tar',
help='Path to the saved model checkpoint')
parser.add_argument('--out-stride', type=int, default=16,
help='Network output stride (default: 16)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='If set, disables CUDA training')
parser.add_argument('--gpu-ids', type=str, default='0',
help='Comma-separated list of GPU IDs to use (default: 0)')
parser.add_argument('--dataset', type=str, default='visdrone',
choices=['pascal', 'coco', 'cityscapes', 'visdrone'],
help='Dataset name (default: visdrone)')
parser.add_argument('--crop-size', type=int, default=513,
help='Crop size for inference (default: 513)')
parser.add_argument('--num_classes', type=int, default=10,
help='Number of classes (default: 10 for VisDrone)')
parser.add_argument('--sync-bn', type=bool, default=None,
help='Whether to use synchronized batch normalization')
parser.add_argument('--freeze-bn', type=bool, default=False,
help='If set, freezes batch normalization parameters')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda:
try:
args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
except ValueError:
raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
if args.sync_bn is None:
args.sync_bn = args.cuda and len(args.gpu_ids) > 1
# Load the model
model = DeepLab(num_classes=args.num_classes,
backbone=args.backbone,
output_stride=args.out_stride,
sync_bn=args.sync_bn,
freeze_bn=args.freeze_bn)
ckpt = torch.load(args.ckpt, map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
model = model.cuda() if args.cuda else model
model.eval()
# Updated transformation pipeline for inference (only apply to images)
composed_transforms = transforms.Compose([
transforms.Resize((args.crop_size, args.crop_size)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
if not os.path.exists(args.out_path):
os.makedirs(args.out_path)
# Inference
for name in os.listdir(args.in_path):
if name.startswith("."): # Skip hidden files like .ipynb_checkpoints
continue
image_path = os.path.join(args.in_path, name)
image = Image.open(image_path).convert('RGB')
tensor_in = composed_transforms(image).unsqueeze(0)
if args.cuda:
tensor_in = tensor_in.cuda()
with torch.no_grad():
output = model(tensor_in)
seg_map = torch.max(output, 1)[1].detach().cpu().numpy()
seg_map = decode_segmap(seg_map[0], dataset=args.dataset)
output_image_path = os.path.join(args.out_path, name.replace(".jpg", "_mask.png"))
save_image(torch.tensor(seg_map).permute(2, 0, 1), output_image_path)
print(f"Processed {name}, saved segmentation map to {output_image_path}")
if __name__ == "__main__":
main()