图像分割:Unet的pytorch代码实现(一)如何使用自己的数据集

深度学习基础:如何使用自己的数据集

通过继承torch的Dataset类,来实现加载自己的数据集。

本文以ISIC2018数据集(这个是开源的数据集)为例。

需要重写的三个函数

import torch


# 数据加载器
class Reader(torch.utils.data.Dataset):  # 数据读取
    """
    读取数据
    """

    def __init__(self):
        super().__init__()
        # 这里可以进行一些数据的预处理,比如类型转换、数据增强等。
        # 一般都在在这里生成两列表,一个是所有输入数据的列表,另一个是所有标签的列表,它们是一一对应的
        pass

    def __getitem__(self, item):
        pass
        # 这里需要返回数据集对应item的两Tenser类型的数据,一个是输入数据,另一个是标签数据

    def __len__(self):
        # 返回数据集长度
        pass

使用自己的数据集

这里是图像预处理部分。根据需要对数据进行处理,实现数据增强的效果。

def get_new_data(img, label):
    """
    数据预处理:把图像矩阵填充为方阵。
    :param img: Image类型,输入特征图
    :param label: Image类型,单通道
    :return:
    """
    try:
        # 这个异常要不要都行,防止有一些图像处理不了程序直接全部异常退出的。
        img_arr = np.array(img)
        max_size = max(img.size)
        # 填充图像矩阵成为(max_len, max_len)
        img_arr = np.pad(img_arr, ((0, max_size - img_arr.shape[0]), (0, max_size - img_arr.shape[1]), (0, 0)),
                         'constant', constant_values=255)
        new_img = Image.fromarray(img_arr).convert('RGB')

        label_arr = np.array(label)
        label_arr = np.pad(label_arr, ((0, max_size - label_arr.shape[0]), (0, max_size - label_arr.shape[1])),
                           'constant', constant_values=0)
        new_label = Image.fromarray(label_arr).convert('L')
        return new_img, new_label
    except:
        return img, label

Dataset类 

import os

import numpy as np
import torch
from PIL import Image


# 数据集读取
class Reader(torch.utils.data.Dataset):  # 数据读取
    """
    读取数据
    """

    def __init__(self, images_path, labels_path, transform):
        """

        :param images_path: 图片地址
        :param labels_path: 标签地址
        :param transform: 类型转换对象
        """
        super().__init__()
        # 获取数据列表

        # transform:将类型转换为Tenser类型。
        self.transform = transform
        # 用于保存原始图像信息,一般保存图像名称,需要用时根据名称打开图像。
        datas = []

        labels_list = os.listdir(labels_path)  # 生成路径
        # 读取图像原始信息,这里为图像地址
        for i in labels_list:  # 遍历图片存放的目录
            if '.png' in i:
                # 我使用的数据集标签和原图的名称是基本一样的,只是目录不同
                label_path = os.path.join(labels_path, i)
                image_path = os.path.join(images_path, i.split('_')[0] + '_' + i.split('_')[1] + '.jpg')
                # print(label_path, image_path)
                datas.append((image_path, label_path))
        print("共{}个数据".format(datas.__len__()))
        self.datas = datas

    def __getitem__(self, item):
        img_path, label_path = self.datas[item]
        # 根据需要读取RGB图或者灰度图
        img = Image.open(img_path)
        label = Image.open(label_path).convert('L')
        # 数据预处理
        img, label = get_new_data(img, label)
        # 将图片从Image类型转为Tensor类型
        img = self.transform(img)
        label = self.transform(label)
        # 返回数据集对应item位置的图像和标签,都是Tensor类型
        return img, label

    def __len__(self):
        # 返回数据集长度
        return self.datas.__len__()

这样我们就设置好了一个Dataset。由于数据集格式不同,Dataset的数据读取方式也不同,需要灵活运用。 

构造一个数据加载器

 

if __name__ == '__main__':
    import torchvision
    from torch.utils.data import DataLoader


    transform = torchvision.transforms.Compose(
        [torchvision.transforms.Resize((64, 64)),   # 缩放
         torchvision.transforms.ToTensor()])        # 类型转换

    images_path = r"E:\数据集\1-2_Validation_Input\ISIC2018_Task1-2_Validation_Input"
    label_path = r"E:\数据集\1_Validation_GroundTruth\ISIC2018_Task1_Validation_GroundTruth"


    dateset = Reader(images_path, label_path, transform)

    train_DataLoader = DataLoader(  # 数据加载器
        dataset=dateset,    # 选择dataset
        batch_size=3,       # 选择batch大小
        num_workers=0,      # windows得设置这个,不然有时候会报错
        shuffle=True)       # 是否乱序

查看是否正常读取数据

    for data in train_DataLoader:
        img, label = data
        print(img.shape, label.shape)

输出:

torch.Size([3, 3, 64, 64]) torch.Size([3, 1, 64, 64])

每个维度的含义[batch_size, 图像通道数, img_x_size, img_y_size]

batch_size:批次大小

图像通道数:比如RGB图像是三通道图像,灰度图是单通道图像

好的,下面是一个使用PyTorch实现的简单的UNet模型。 首先,需要安装PyTorch库。可以使用以下命令进行安装: ``` pip install torch ``` 然后,我们可以定义一个UNet模型。 ```python import torch import torch.nn as nn class DoubleConv(nn.Module): def __init__(self, in_channels, out_channels): super(DoubleConv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.conv(x) class Up(nn.Module): def __init__(self, in_channels, out_channels): super(Up, self).__init__() self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) self.conv = DoubleConv(in_channels, out_channels) def forward(self, x1, x2): x1 = self.up(x1) diffX = x2.size()[2] - x1.size()[2] diffY = x2.size()[3] - x1.size()[3] x1 = nn.functional.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) x = torch.cat([x2, x1], dim=1) return self.conv(x) class UNet(nn.Module): def __init__(self, n_channels, n_classes): super(UNet, self).__init__() self.inc = DoubleConv(n_channels, 64) self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(64, 128)) self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(128, 256)) self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(256, 512)) self.down4 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(512, 1024)) self.up1 = Up(1024, 512) self.up2 = Up(512, 256) self.up3 = Up(256, 128) self.up4 = Up(128, 64) self.outc = nn.Conv2d(64, n_classes, kernel_size=1) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) return self.outc(x) ``` 在这个模型中,我们首先定义了一个双卷积层(DoubleConv)和一个上采样层(Up),然后将它们组合起来构建了一个UNet模型。UNet模型用于图像分割,将输入图像分割成多个部分,每个部分都对应着一个特定的标签。UNet模型的结构类似于自编码器,由一个下采样器和一个上采样器组成。下采样器用于提取特征,上采样器用于将特征图恢复到原始图像大小,并将特征图与下采样器对应的特征图进行特征融合。 接下来,我们可以定义一个函数来训练这个模型。 ```python def train(model, train_loader, val_loader, criterion, optimizer, n_epochs=10): for epoch in range(n_epochs): train_loss = 0 val_loss = 0 model.train() for batch_idx, (data, target) in enumerate(train_loader): optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() train_loss += loss.item() model.eval() with torch.no_grad(): for batch_idx, (data, target) in enumerate(val_loader): output = model(data) loss = criterion(output, target) val_loss += loss.item() train_loss /= len(train_loader.dataset) val_loss /= len(val_loader.dataset) print('Epoch: {} Train Loss: {:.6f} Val Loss: {:.6f}'.format( epoch + 1, train_loss, val_loss)) ``` 在训练函数中,我们首先循环训练数据集,计算损失并更新模型参数。然后我们循环验证数据集,计算损失并输出训练和验证损失。 接下来,我们可以定义一个函数来测试这个模型。 ```python def test(model, test_loader): model.eval() test_loss = 0 correct = 0 with torch.no_grad(): for data, target in test_loader: output = model(data) test_loss += criterion(output, target).item() pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) print('Test Loss: {:.6f} Test Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) ``` 在测试函数中,我们首先将模型设置为评估模式,然后循环测试数据集,计算损失并输出测试精度。 最后,我们可以定义一个函数来进行训练和测试的循环。 ```python def train_and_test(model, train_loader, val_loader, test_loader, criterion, optimizer, n_epochs=10): for epoch in range(n_epochs): train(model, train_loader, val_loader, criterion, optimizer, n_epochs) test(model, test_loader) model = UNet(n_channels=3, n_classes=2) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() train_and_test(model, train_loader, val_loader, test_loader, criterion, optimizer, n_epochs=10) ``` 在这个函数中,我们首先定义了一些超参数,包括训练轮数、优化器和损失函数。然后我们循环训练和测试模型,并在每个epoch结束后输出测试结果。 这就是一个简单的基于PyTorchUNet模型。当然,这里只是给出了一个简单的实现,还可以进行更多的优化和改进,例如使用更复杂的模型、使用预训练模型等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值