[极简]pytorch版Unet训练自己的数据集

本文指导如何准备VOC数据集,包括图片和标注组织,然后详细介绍了使用unet.py进行训练的参数调整,如类别数、Dice Loss、预训练权重等。特别关注了批量大小对内存需求的影响。最后介绍了测试阶段的predict.py和模型预测功能。
摘要由CSDN通过智能技术生成

一、准备数据集

VOCdevkit
    VOC2007
         JPEGImages
         SegmentationClass
         ImageSets
              Segmentation
                   test.txt
                   train.txt
                   trainval.txt
                   val.txt

把所有jpg原图片放到JPEGImages
把所有png标注图片放到SegmentationClass
运行voc2unet.py脚本

二、训练

train.py
关键参数
1、num_classes=(类别数+1)
2、dice_loss=True/False
# 种类少(几类)时,设置为True
# 种类多(十几类)时,如果batch_size比较大(10以上),那么设置为True
# 种类多(十几类)时,如果batch_size比较小(10以下),那么设置为False
3、pretraind=True/False(是否使用预训练权重)
model_path = r"model_data/unet_voc.pth"(预训练权重路径)
4、lr = 1e-4
Init_Epoch = 0
Interval_Epoch = 25
Batch_size = 2

lr = 1e-5
Interval_Epoch = 25
Epoch = 50
Batch_size = 2
主干特征提取网络特征通用,冻结训练可以加快训练速度
也可以在训练初期防止权值被破坏。
Init_Epoch为起始世代
Interval_Epoch为冻结训练的世代
Epoch总训练世代
提示OOM或者显存不足请调小Batch_size

三、测试

predict.py(可自动保存)

unet.py
video.py

对于使用 PyTorch 训练自己的数据集,你可以按照以下步骤进行: 1. 准备数据集:将你的数据集划分为训练集和验证集,并组织成 PyTorch 的 Dataset 类的形式。Dataset 类需要实现 `__len__()` 和 `__getitem__()` 方法,用于返回数据集大小和获取样本。 2. 数据预处理:根据你的任务需求,对图像进行必要的预处理操作,例如缩放、裁剪、归一化等。你可以使用 PyTorch 提供的图像处理工具包 torchvision 来方便地完成这些操作。 3. 定义网络模型:使用 PyTorch 构建 UNet 模型。你可以自己实现模型结构,也可以使用现有的开源实现。 4. 定义损失函数:根据你的任务类型,选择适当的损失函数。例如,对于图像分割任务,你可以使用交叉熵损失函数或 Dice Loss。 5. 定义优化器:选择合适的优化器来更新模型的参数。常用的优化器包括 Adam、SGD 等,你可以根据自己的需求进行选择。 6. 训练模型:使用 DataLoader 来加载数据,将数据输入到网络中进行训练。在每个 epoch 结束后,计算损失函数并进行反向传播更新模型参数。 7. 评估模型:使用验证集对训练的模型进行评估,计算预测结果的准确率、召回率、F1 值等指标。 8. 预测新数据:使用训练好的模型对新数据进行预测。将新数据输入到模型中,得到预测结果。 这些是基本的步骤,你可以根据自己的具体情况进行调整和扩展。希望这些对你有所帮助!
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值