UNet - 数据加载 Dataset

目录

1. 介绍

2. 数据处理 dataset

2.1 预处理

2.2 加载数据

2.2.1 初始化

2.2.2 返回数据

2.2.3 样本数量

3. 测试一下

4. 完整代码


1. 介绍

之前介绍完了Unet网络的搭建,接下来说一下要解决的任务。

本章介绍的是:数据的加载处理

下面是整个项目:

  • data里面存放的是训练的数据
  • predict 是存放的是需要预测的数据
  • result 是predict里面预测出来的结果
  • dataset 是数据加载的文件、model 是UNet网络、train是训练、predict是预测

本项目参考这篇文章:UNet模型训练,深度解析! ,网络做了一些优化和更改,整个项目完成会上传到CSDN,数据可以在链接里面获取

因为data数据只有30张,并且没有test集,所以这里手工分类了一下。将对应的image和label取出来放到test里面即可,这里21张用于train,9张用于test

 

样本图片:

 

对应label:

 

2. 数据处理 dataset

有关内容可以参考:关于pytorch的数据处理-数据加载Dataset

2.1 预处理

因为UNet 网络,我们希望的输入是480*480的灰度图,所以预处理的时候要改变一个size

图像本身就是灰度图,所以这里不需要转换

最后要将图像转为Tensor 

这里没有用数据增强:翻转、随即裁剪等等。因为这里不确定随机的翻转对image和label是否是一致的。

这里可以通过设置字典,对image,进行normalization

2.2 加载数据

观察下目录结构,后面用得到

 

2.2.1 初始化

这里如果定义加载类的话,需要继承  from torch.utils.data import Dataset 里面的Dataset

初始化init 方法里面实现的是初始化相关的操作,例如指定文件的路径和预处理等等

这里root指定要处理数据的目录,这里指定的是train里面的image

imgs 只会读取里面每个文件

 想要获得image下具体图片的路径就要将root + imgs ,也就是self.imgs

 

2.2.2 返回数据

getitem 是返回一个样本,那么既然这个方法返回的就是我们需要的每个样本,那么读取每个图像,甚至对图像操作都应该在getitem里面

首先,self.imgs 是个列表,里面存放的是整个训练图片的路径。根据index索引获取每个图片,

因为train和test里面的图像和标签都是相同的文件名,观察每个图片的路径,只需要将train替换成label就可以获取图像对应的标签图像了

 

 


 通过上面的open获取每个对应的图片和图片的label


 这里就是简单的预处理

需要注意的是,因为这里的label不是二值图片,所以需要转换一下。因为预处理的ToTensor会将像素 / 255 变成0-1之间,所以这里将大于等于0.5的设置为1,小于0.5的设置为0

最后返回image和label就行了

2.2.3 样本数量

 

3. 测试一下

image:

label:

4. 完整代码

code:

import os
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image


transform = transforms.Compose([
    transforms.Resize((480,480)),        # 缩放图像
    transforms.ToTensor(),               # 转为Tensor
])


# 数据处理文件
class Data_Loader(Dataset):          # 加载数据
    def __init__(self, root, transforms = transform):               # 指定路径、预处理等等
        imgs = os.listdir(root)                                     # 获取root文件下的文件
        self.imgs = [os.path.join(root,img) for img in imgs]        # 获取每个文件的路径
        self.transforms = transforms                                # 预处理

    def __getitem__(self, index):    # 读取图片,返回一条样本
        image_path = self.imgs[index]                       # 根据index读取图片
        label_path = image_path.replace('image', 'label')   # 把路径中的image替换成label,就找到对应数据的label

        image = Image.open(image_path)                      # 读取图片和对应的label图
        label = Image.open(label_path)

        if self.transforms:                                 # 判断是否预处理
            image = self.transforms(image)

            label = self.transforms(label)
            label[label>=0.5] = 1               # 这里转为二值图片
            label[label< 0.5] = 0

        return image, label

    def __len__(self):  # 返回样本的数量
        return len(self.imgs)


# if __name__ == "__main__":
# 
#     dataset = Data_Loader("./data/test/image")               # 加载数据
# 
#     for image,label in dataset:
#         print(image)
#         print('image size:',image.size())   # image size: torch.Size([1, 480, 480])
#         print(label)
#         print('label size:',label.size())   # label size: torch.Size([1, 480, 480])
#         break

  • 7
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
对于使用 PyTorch 训练自己的数据集,你可以按照以下步骤进行: 1. 准备数据集:将你的数据集划分为训练集和验证集,并组织成 PyTorch 的 Dataset 类的形式。Dataset 类需要实现 `__len__()` 和 `__getitem__()` 方法,用于返回数据集大小和获取样本。 2. 数据预处理:根据你的任务需求,对图像进行必要的预处理操作,例如缩放、裁剪、归一化等。你可以使用 PyTorch 提供的图像处理工具包 torchvision 来方便地完成这些操作。 3. 定义网络模型:使用 PyTorch 构建 UNet 模型。你可以自己实现模型结构,也可以使用现有的开源实现。 4. 定义损失函数:根据你的任务类型,选择适当的损失函数。例如,对于图像分割任务,你可以使用交叉熵损失函数或 Dice Loss。 5. 定义优化器:选择合适的优化器来更新模型的参数。常用的优化器包括 Adam、SGD 等,你可以根据自己的需求进行选择。 6. 训练模型:使用 DataLoader 来加载数据,将数据输入到网络中进行训练。在每个 epoch 结束后,计算损失函数并进行反向传播更新模型参数。 7. 评估模型:使用验证集对训练的模型进行评估,计算预测结果的准确率、召回率、F1 值等指标。 8. 预测新数据:使用训练好的模型对新数据进行预测。将新数据输入到模型中,得到预测结果。 这些是基本的步骤,你可以根据自己的具体情况进行调整和扩展。希望这些对你有所帮助!
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

听风吹等浪起

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值