从零手写Unet代码及数据

在这里插入图片描述
数据集资源已上传,不需要积分(可能还在审核)。
下载链接

一.网络模型(model.py)

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Padding 1 with stride 1 keeps the spatial size unchanged; only the channel
    count changes from ``in_channels`` to ``out_channels``. The convolutions
    have ``bias=False`` because the following BatchNorm adds its own shift.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        # The attribute name ``conv`` and the layer order are kept unchanged so
        # existing checkpoint keys (conv.0.weight, conv.1.running_mean, ...) load.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class Unet(nn.Module):
    """U-Net encoder/decoder for per-pixel prediction.

    Args:
        in_channels: channels of the input image (e.g. 3 for RGB).
        out_channels: channels of the output logits (e.g. 1 for a binary mask).
        features: channel width of each encoder level, shallowest first. Any
            sequence is accepted; the bottleneck uses ``2 * features[-1]``.

    ``forward`` returns raw logits with the same spatial size as the input;
    apply ``sigmoid`` (or use ``BCEWithLogitsLoss``) downstream.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super(Unet, self).__init__()
        # Immutable default + defensive copy: a list default would be shared
        # by every instance, and we never want to alias the caller's list.
        features = tuple(features)
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level, each feeding the next.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder, deepest level first: a 2x transposed-conv upsample that
        # halves the channels, then a DoubleConv over the concatenation with
        # the matching skip connection (hence ``feature * 2`` input channels).
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        # self.ups alternates [upsample, DoubleConv, upsample, DoubleConv, ...].
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # MaxPool floors odd sizes, so after upsampling x can be a pixel
            # smaller than its skip. Compare/resize only the spatial dims
            # (bilinear, align_corners=False, as torchvision's resize does).
            if x.shape[2:] != skip_connection.shape[2:]:
                x = nn.functional.interpolate(
                    x,
                    size=skip_connection.shape[2:],
                    mode="bilinear",
                    align_corners=False,
                )

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)



def test():
    """Smoke test: run a random 3x1x160x160 batch through Unet and print the output shape."""
    sample = torch.randn((3, 1, 160, 160))
    net = Unet(in_channels=1, out_channels=1)
    output = net(sample)
    print(output.shape)


if __name__ == "__main__":
    test()
    

二.数据(dataset.py)

import os
from  PIL import Image
from torch.utils.data import Dataset
import numpy as np

class CarvanaDataset(Dataset):
    """Image/mask pairs for binary segmentation.

    Every entry of ``image_dir`` is treated as an image. Its mask is the file
    in ``mask_dir`` with the same stem plus ``_mask.gif``
    (``abc.jpg`` -> ``abc_mask.gif``).

    Args:
        image_dir: directory of input images (loaded as RGB).
        mask_dir: directory of mask images (loaded as grayscale float32).
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` keys.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir order is arbitrary and platform dependent; sort so that
        # index -> file is reproducible across runs and machines.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _mask_filename(image_name):
        """Return the mask filename for ``image_name`` (``abc.jpg`` -> ``abc_mask.gif``)."""
        # Strip only the real extension: str.replace(".jpg", ...) would also
        # rewrite ".jpg" elsewhere in the name and would miss ".JPG".
        stem, _ = os.path.splitext(image_name)
        return stem + "_mask.gif"

    def __getitem__(self, index):
        image_name = self.images[index]
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, self._mask_filename(image_name))
        # Context managers close the file handles (PIL keeps GIF files open
        # after loading unless closed explicitly).
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as msk:
            mask = np.array(msk.convert("L"), dtype=np.float32)
        # Map 255-valued pixels to 1.0 so the mask is a 0/1 target; other
        # values are left unchanged.
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

三.训练

train.py

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import Unet
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    save_predictions_as_imgs,
    check_accuracy)
# Adam learning rate used by main().
LEARNING_RATE = 1e-4
# Train on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4  # batch size for both the train and validation DataLoaders
NUM_EPOCHS =10  # number of passes over the training set
NUM_WORKERS = 2  # DataLoader worker processes (forwarded to get_loaders)
# Every image/mask is resized to IMAGE_HEIGHT x IMAGE_WIDTH by A.Resize in main().
IMAGE_HEIGHT =160
IMAGE_WIDTH = 240
PIN_MEMORY = True  # forwarded to get_loaders as pin_memory
LOAD_MODEL = False  # resume switch for my_checkpoint.pth.tar; see the LOAD_MODEL check in main()
# Image / mask folders for the training and validation splits (read by get_loaders).
TRAIN_IMG_DIR = "data/train_images/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val_images/"
VAL_MASK_DIR = "data/val_masks/"

def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch over ``loader`` with mixed precision on CUDA.

    Args:
        loader: iterable of ``(data, targets)`` batches. Targets are cast to
            float and given a channel axis (``unsqueeze(1)``) before the loss.
        model: network whose outputs are passed to ``loss_fn``.
        optimizer: stepped once per batch through ``scaler``.
        loss_fn: callable ``loss_fn(predictions, targets)``.
        scaler: gradient scaler used for the scaled backward pass and step.
    """
    # torch.cuda.amp.autocast() is deprecated in favour of torch.autocast.
    # The old call only warned and disabled itself when CUDA was missing, so
    # enabling autocast exactly when DEVICE is "cuda" keeps the same behaviour.
    use_amp = DEVICE == "cuda"
    loop = tqdm(loader)
    for data, targets in loop:
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        with torch.autocast(device_type="cuda", enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

def main():
    """Train Unet on the configured folders, checkpointing and evaluating after every epoch."""
    # Training augmentation: random geometry, then scale pixels to [0, 1]
    # (mean 0, std 1, divided by 255) and convert to a CHW tensor.
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )
    # Validation must be deterministic: no random rotate/flip, otherwise the
    # reported accuracy/Dice and the saved predictions change run to run.
    val_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = Unet(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()  # the model outputs raw logits
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transform,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        # map_location lets a GPU-saved checkpoint resume on a CPU-only machine.
        load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)

    scaler = torch.cuda.amp.GradScaler()
    for _ in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # Save model and optimizer state after every epoch.
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),  # key previously had a stray trailing space
        }
        save_checkpoint(checkpoint)
        # Report pixel accuracy / Dice on the validation split.
        check_accuracy(val_loader, model, device=DEVICE)
        # Write predicted and ground-truth masks for visual inspection.
        save_predictions_as_imgs(
            val_loader, model, folder="saved_images/", device=DEVICE
        )


if __name__ == "__main__":
    main()

utils.py

import torch
import torchvision
from dataset import CarvanaDataset
from torch.utils.data import DataLoader


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialize ``state`` (typically a dict of state_dicts) to ``filename`` via torch.save."""
    banner = "=>Saving checkpoint"
    print(banner)
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)


def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    """Build ``(train_loader, val_loader)`` over two CarvanaDataset folders.

    Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory``;
    only the training loader shuffles.
    """

    def _make_loader(image_dir, mask_dir, transform, shuffle):
        # One dataset + DataLoader pair with the shared loader settings.
        dataset = CarvanaDataset(
            image_dir=image_dir,
            mask_dir=mask_dir,
            transform=transform,
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
        )

    train_loader = _make_loader(train_dir, train_maskdir, train_transform, shuffle=True)
    val_loader = _make_loader(val_dir, val_maskdir, val_transform, shuffle=False)
    return train_loader, val_loader


def check_accuracy(loader, model, device="cuda"):
    """Print and return pixel accuracy and mean per-batch Dice over ``loader``.

    Predictions are ``sigmoid(model(x)) > 0.5``; targets ``y`` get a channel
    axis (``unsqueeze(1)``) to line up with them. Dice is computed per batch
    and averaged over the number of batches seen.

    Args:
        loader: any iterable of ``(x, y)`` tensor batches.
        model: network returning logits.
        device: device the batches are moved to.

    Returns:
        ``(accuracy_percent, dice)`` as Python floats, or ``(nan, nan)`` when
        the loader yields no pixels. (Callers that ignore the result, as the
        original ``None`` return required, are unaffected.)

    The model's previous train/eval mode is restored on exit.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0.0
    num_batches = 0
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = (torch.sigmoid(model(x)) > 0.5).float()
                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                # 1e-8 avoids 0/0 when both prediction and target are empty.
                dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
                num_batches += 1
    finally:
        # Restore rather than force train mode, so eval-mode callers stay in eval.
        model.train(was_training)

    if num_pixels == 0:
        print("check_accuracy: loader produced no samples")
        return float("nan"), float("nan")

    # Convert to plain Python numbers so the output is not printed as "tensor(...)".
    num_correct = int(num_correct)
    accuracy = num_correct / num_pixels * 100
    dice = float(dice_score) / num_batches
    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    print(f"Dice score: {dice}")
    return accuracy, dice


def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    """Write thresholded predictions and target masks for every batch of ``loader``.

    For batch ``idx`` this saves ``pred_{idx}.png`` (``sigmoid(model(x)) > 0.5``)
    and ``{idx}.png`` (``y`` with a channel axis added) inside ``folder``,
    creating the folder if it does not exist. The model's previous
    train/eval mode is restored afterwards.
    """
    import os  # local import keeps this function self-contained

    # save_image does not create directories; without this the first epoch
    # crashes with FileNotFoundError on a fresh checkout.
    os.makedirs(folder, exist_ok=True)
    was_training = model.training
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = (torch.sigmoid(model(x)) > 0.5).float()
            torchvision.utils.save_image(
                preds, os.path.join(folder, f"pred_{idx}.png")
            )
            torchvision.utils.save_image(
                y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
            )
    finally:
        model.train(was_training)

四.预测(predict.py)

预测用的图片放在 predict 目录下:运行后输入 predict/test.jpg,预测结果会保存在同一目录下,文件名为 testout.png。由于数据集单一,模型效果较差,适合练手。

import torch
import torchvision
import cv2
from PIL import Image
import numpy as np
from model import Unet
from utils import (
    load_checkpoint,
    )


if __name__ == "__main__":
    import os  # stdlib; only needed by this script entry point

    # Fall back to the CPU instead of crashing when no GPU is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Unet(in_channels=3, out_channels=1).to(device)
    model.eval()
    # map_location lets a checkpoint saved on the GPU load on a CPU-only machine.
    load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=device), model)
    img = input('Input image filename:')
    # "predict/test.jpg" -> "predict/testout"; ".png" is appended on save.
    name = os.path.splitext(img)[0] + "out"

    # Open once and close the file handle when done.
    with Image.open(img) as pil_image:
        original_w, original_h = pil_image.size  # only used to restore the size (see below)
        # 240x160 matches IMAGE_WIDTH x IMAGE_HEIGHT used in training.
        rgb = pil_image.convert("RGB").resize((240, 160), Image.BICUBIC)
    image = np.array(rgb)

    # HWC uint8 -> 1xCxHxW float in [0, 1], the same scaling as the training
    # A.Normalize(mean=0, std=1, max_pixel_value=255) + ToTensorV2 pipeline.
    image = np.expand_dims(np.transpose(image, (2, 0, 1)), 0)
    print(image.shape)
    image = torch.from_numpy(image).float()
    image /= 255.0
    image = image.to(device)
    with torch.no_grad():
        preds = torch.sigmoid(model(image))
        preds = (preds > 0.5).float()
        print(preds.shape)

    # Optional: resize the mask back to the original image size.
    # preds = preds.cpu().numpy()
    # preds = np.squeeze(preds, axis=(0, 1))
    # preds = cv2.resize(preds, dsize=(original_w, original_h), interpolation=cv2.INTER_LINEAR)
    # preds = torch.from_numpy(preds)
    torchvision.utils.save_image(preds, f"{name}.png")


  • 3
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值