语义分割半监督训练方法(SST),以UniMatch(CVPR2023)为基础,详述原理、创新点及代码解析

一、背景

如今,深度学习技术的蓬勃发展在许多领域取得了显著成就,但同时也面临着一个不可忽视的瓶颈——对大规模标注数据的高度依赖。特别是在语义分割领域,模型需要像素级的标注数据,要求人工精确地标注每一张图像的每一个像素,这无疑是一个费时费力的过程。

1.1 标注数据的挑战

主要在于成本高昂,
标注一张高分辨率的语义分割图像需要较高精力和时间(而且一般要求的数据量较多),尤其是在复杂场景(如城市街景或医学影像)中,标注成本可能成倍增长。这种劳动密集型任务对标注员的专业技能提出了更高的要求,同时也显著提高了项目的总体成本。
如公共的Cityscapes数据集,需要一个实验室数年的数据收集和标注工作。
在这里插入图片描述

1.2 适用性限制

某些领域(例如医学影像分析、遥感解读)中,标注不仅需要大量时间,还需要领域专家的参与。这使得标注工作变得更加昂贵甚至难以实现。此外,在某些敏感领域,数据隐私问题也限制了标注工作的广泛开展。

1.3 数据稀缺性

在一些新兴或特殊的应用场景中,获取大规模标注数据几乎是不可能的,例如极端气候环境下的图像分割任务,或者特殊领域的小众数据集。

1.4 半监督学习的机遇

为了解决这些问题,**半监督学习(Semi-Supervised Learning, SSL)**提供了一种潜在的解决方案。它通过利用大量未标注数据,仅依赖少量标注数据来训练模型,降低了对大规模标注数据的依赖性。在语义分割中,这种方法特别适用,可以通过未标注图像的特性提取更多有价值的信息。

二、UniMatch(CVPR2023)半监督语义分割算法

2.1 原文链接

https://openaccess.thecvf.com/content/CVPR2023/papers/Yang_Revisiting_Weak-to-Strong_Consistency_in_Semi-Supervised_Semantic_Segmentation_CVPR_2023_paper.pdf
在这里插入图片描述

2.2 该论文的创新点汇总

2.2.1 背景

FixMatch是一种半监督分类模型,通过弱扰动图像的预测结果来监督强扰动图像的预测。这种方法在许多任务中表现优秀,但它的成功严重依赖手动设计的强数据增强方式,限制了扰动空间的广度。此外在FixMatch中所有的扰动都基于image-level,作者认为feature-level的扰动同样重要,可以增加模型的鲁棒性。

2.2.2 提出的改进

(1)扩展更广泛的扰动空间

引入了一个辅助特征扰动流(feature perturbation stream),以补充原始图像级扰动。
在弱扰动图像的特征层上施加扰动,实现图像和特征级别的一致性。

在图像输入后经encode提取到feature map后,对feature map进行扰动,再经decoder解码后,得到feature perturbation的p_fp。
在这里插入图片描述

(2)充分利用原始数据增强

开发了双流扰动技术(dual-stream perturbations),从预定义的图像级扰动池中随机生成两个强视图,利用共同的弱视图指导它们。
结合对比学习,获取更具区分性的特征表示。

2.3 取得的成果

算法在公共数据集上测试,精度较好

### 使用 UniMatch V2 进行自定义数据训练 #### 准备环境与依赖项 为了使用 UniMatch V2 对自定义数据集进行训练,首先需要准备合适的开发环境并安装必要的库。这通常涉及创建虚拟环境以及安装特定版本的 PyTorch 和其他依赖包。 ```bash conda create -n unimatch python=3.8 conda activate unimatch pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install -r requirements.txt ``` 上述命令假设读者已经配置好了 CUDA 环境,并且 `requirements.txt` 文件包含了 UniMatch 所需的所有 Python 库[^1]。 #### 数据预处理 对于视频匹配任务而言,输入的数据应当被整理成适合模型读取的形式。一般情况下,这意味着要将原始视频文件转换为图像序列,并按照一定规则命名这些图片文件以便于后续加载。此外,还需要生成相应的标签文件来指示每一对待比较帧之间的关系。 ```python import os from PIL import Image def preprocess_video(video_path, output_dir): cap = cv2.VideoCapture(video_path) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = cap.get(cv2.CAP_PROP_FPS) if not os.path.exists(output_dir): os.makedirs(output_dir) count = 0 while True: ret, frame = cap.read() if not ret: break img_name = f"{output_dir}/frame_{count}.png" image = Image.fromarray(frame) image.save(img_name) count += 1 return {"total_frames": frame_count, "fps": fps} ``` 此段代码展示了如何从给定路径下的视频中提取每一帧作为单独的 PNG 图像保存到指定目录下。 #### 配置参数设置 在开始实际训练之前,还需调整一些超参数以适应新的数据分布特性。比如批量大小(batch size),迭代次数(iterations), 学习率(learning rate)等都可能影响最终效果的好坏。具体数值的选择往往取决于实验者的经验和初步测试的结果。 ```yaml train: batch_size: 8 iterations: 50000 learning_rate: 0.0001 dataset: path_to_train_data: "./data/train/" path_to_val_data: "./data/validation/" model: pretrained_weights: "unimatch_v2.pth" ``` 以上 YAML 片段给出了一个简单的配置模板,其中指定了训练过程中需要用到的关键参数及其默认值。 #### 启动训练过程 最后一步就是编写脚本来启动整个训练流程了。这里会涉及到调用前面提到过的各个组件——包括但不限于初始化网络结构、载入预训练权重、构建 dataloader 来提供 mini-batches 的样本等等。 ```python from unimatch.unimatch import UniMatch import yaml import torch.optim as optim from torch.utils.data import DataLoader from dataset.custom_dataset import CustomDataset with open('config.yaml', 'r') as file: config = yaml.safe_load(file) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") net = UniMatch().to(device) optimizer = optim.Adam(net.parameters(), lr=config['train']['learning_rate']) training_set = CustomDataset(config["dataset"]["path_to_train_data"]) validation_set = CustomDataset(config["dataset"]["path_to_val_data"]) dataloader_training = DataLoader(training_set, batch_size=config['train']['batch_size'], shuffle=True) dataloader_validation = DataLoader(validation_set, batch_size=config['train']['batch_size'], shuffle=False) for iteration in range(config['train']['iterations']): net.train() for i_batch, sample_batched in enumerate(dataloader_training): optimizer.zero_grad() loss = net(sample_batched['input_1'].to(device), sample_batched['input_2'].to(device)) loss.backward() optimizer.step() print(f"Iteration [{iteration}/{config['train']['iterations']}], Loss: {loss.item()}") print("Training completed.") ``` 这段程序片段实现了完整的训练循环逻辑,它能够周期性地更新模型参数直至达到预定的最大迭代数为止。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值