使用Pytorch搭建U-Net网络

原因

github上关于Unet网络的实现不少,其中milesial实现了基于pytorch的,但是,在运行过程中,发现其代码训练很慢,而且特别占内存,在显存为12G的3060上的batch_szie也只能为2。故另寻其方法,好在b站博主霹雳吧啦Wz实现了pytorch的简化版本,这里我推荐一下这位博主,很适合初学者。

搭建U-Net

根据其在github上的readme,搭建好环境,只需要改变my_dataset.py文件即可运行自己的数据集。
my_dataset.py更改如下:

import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset


class DriveDataset(Dataset):
    def __init__(self, root: str, train: bool, transforms=None):
        super(DriveDataset, self).__init__()
        self.flag = "training" if train else "test"
        data_root = os.path.join(root, "DRIVE", self.flag)
        assert os.path.exists(data_root), f"path '{data_root}' does not exists."
        self.transforms = transforms
        img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".jpg")]
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]
        mask_names = [i for i in os.listdir(os.path.join(data_root, "mask")) if i.endswith(".png")]
        self.mask_list = [os.path.join(data_root, "mask", i) for i in mask_names]

    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx]).convert('RGB')
        mask = Image.open(self.mask_list[idx])

        if self.transforms is not None:
            img, mask = self.transforms(img, mask)

        return img, mask

    def __len__(self):
        return len(self.img_list)

    @staticmethod
    def collate_fn(batch):
        images, targets = list(zip(*batch))
        batched_imgs = cat_list(images, fill_value=0)
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets


def cat_list(images, fill_value=0):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    batch_shape = (len(images),) + max_size
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched_imgs

数据集的格式

training和test分别存放训练和验证的数据集,images存放jpg格式格式的图片,mask存放png格式的图片,如果有其他格式的请在 i.endswith(“.jpg”)更改,将".jpg:和".png"改为相应的格式。
在这里插入图片描述

运行

python train.py

参考连接

链接: 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)
链接: deep-learning-for-image-processing/pytorch_segmentation/unet/
链接: milesial/Pytorch-UNet

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值