用PyTorch搭建并训练一个简单的图像分割模型

一、分割模型的搭建

先从最简单的模型搭建开始:输入图像大小为3x224x224,卷积(编码)部分采用带BatchNorm的VGG11结构,经过第5个max pooling之后开始上采样,再经过5个反卷积(转置卷积)层还原到原始图像大小,最后用一个1x1卷积输出2个类别(目标/非目标)的得分图。
model.py

import torch
from torch import nn


class Net(nn.Module):
    """VGG11-style encoder with a transposed-convolution decoder.

    Five encoder stages (3x3 conv + BatchNorm + ReLU, each closed by a 2x2
    max-pool) shrink the input by a factor of 32.  Five stride-2 transposed
    convolutions then double the resolution again, and a final 1x1
    convolution emits one raw score map (logit) per class.

    Args:
        num_classes: number of output channels of the 1x1 classifier.
            Defaults to 2 (target / non-target).
        in_channels: number of channels of the input image. Defaults to 3.

    Height and width of the input should be multiples of 32; otherwise the
    pooling steps round down and the output is smaller than the input.

    Attribute names and the layer order inside every ``nn.Sequential`` are
    kept stable so that ``state_dict`` keys do not change.
    """

    def __init__(self, num_classes=2, in_channels=3):
        super().__init__()
        # Encoder: channel widths and conv counts follow the VGG11 layout.
        self.encode1 = self._encoder_stage(in_channels, 64, num_convs=1)
        self.encode2 = self._encoder_stage(64, 128, num_convs=1)
        self.encode3 = self._encoder_stage(128, 256, num_convs=2)
        self.encode4 = self._encoder_stage(256, 512, num_convs=2)
        self.encode5 = self._encoder_stage(512, 512, num_convs=2)
        # Decoder: each stage doubles height/width and halves the channels.
        self.decode1 = self._decoder_stage(512, 256)
        self.decode2 = self._decoder_stage(256, 128)
        self.decode3 = self._decoder_stage(128, 64)
        self.decode4 = self._decoder_stage(64, 32)
        self.decode5 = self._decoder_stage(32, 16)
        # Per-pixel classifier producing one logit map per class.
        self.classifier = nn.Conv2d(16, num_classes, kernel_size=1)

    @staticmethod
    def _encoder_stage(in_ch, out_ch, num_convs):
        """Build ``num_convs`` x (3x3 conv, BN, ReLU) followed by a 2x2 max-pool."""
        layers = []
        for i in range(num_convs):
            layers += [
                nn.Conv2d(in_ch if i == 0 else out_ch, out_ch,
                          kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        layers.append(nn.MaxPool2d(2, 2))
        return nn.Sequential(*layers)

    @staticmethod
    def _decoder_stage(in_ch, out_ch):
        """Build a stride-2 transposed conv (doubles H and W), BN and ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Map an image batch ``[b, C, H, W]`` to logits ``[b, num_classes, H, W]``.

        The shape comments below are for the default 3x224x224 input with
        two classes; ``b`` is the batch size.
        """
        out = self.encode1(x)       # [b, 3, 224, 224]  =>  [b, 64, 112, 112]
        out = self.encode2(out)     # [b, 64, 112, 112] =>  [b, 128, 56, 56]
        out = self.encode3(out)     # [b, 128, 56, 56]  =>  [b, 256, 28, 28]
        out = self.encode4(out)     # [b, 256, 28, 28]  =>  [b, 512, 14, 14]
        out = self.encode5(out)     # [b, 512, 14, 14]  =>  [b, 512, 7, 7]
        out = self.decode1(out)     # [b, 512, 7, 7]    =>  [b, 256, 14, 14]
        out = self.decode2(out)     # [b, 256, 14, 14]  =>  [b, 128, 28, 28]
        out = self.decode3(out)     # [b, 128, 28, 28]  =>  [b, 64, 56, 56]
        out = self.decode4(out)     # [b, 64, 56, 56]   =>  [b, 32, 112, 112]
        out = self.decode5(out)     # [b, 32, 112, 112] =>  [b, 16, 224, 224]
        out = self.classifier(out)  # [b, 16, 224, 224] =>  [b, 2, 224, 224]  (one channel per class)
        return out


if __name__ == '__main__':
    # Smoke test: feed a random batch of two 3x224x224 images through the
    # network and print the shape of what comes out.
    dummy_batch = torch.randn(2, 3, 224, 224)
    segmenter = Net()
    print(segmenter(dummy_batch).shape)

二、数据读取

数据存放格式如下所示,图像放在last里,标签放在last_msk里。

├─data
    ├─test
    │  ├─last
    │  └─last_msk
    └─train
        ├─last
        └─last_msk

last:
在这里插入图片描述
last_msk:
在这里插入图片描述
load_img.py:

from torch.utils.data import Dataset
import os
import cv2
import numpy as np


class MyDataset(Dataset):
    """Image / mask pairs read from ``<train_path>/last`` and ``<train_path>/last_msk``.

    Args:
        train_path: directory containing the ``last`` (images) and
            ``last_msk`` (masks) sub-directories.
        transform: optional callable applied to the resized image only;
            the mask is never passed through it.

    Each item is ``(img, lab)`` where ``lab`` is a float32 array of shape
    ``[2, 224, 224]``: channel 0 marks white (255) mask pixels, i.e. the
    background, and channel 1 marks every other pixel, i.e. the target.
    """

    def __init__(self, train_path, transform=None):
        image_dir = os.path.join(train_path, 'last')
        label_dir = os.path.join(train_path, 'last_msk')
        # os.listdir returns names in arbitrary order, so two independent
        # listings are not guaranteed to line up.  Sorting both makes the
        # index-based pairing deterministic.
        # NOTE(review): this assumes a mask carries the same file name as its
        # image (or at least sorts to the same position) - confirm for the data.
        self.images = sorted(os.listdir(image_dir))
        self.labels = sorted(os.listdir(label_dir))
        assert len(self.images) == len(self.labels), 'Number does not match'
        self.transform = transform
        # (image path, mask path) tuples, one per sample.
        self.images_and_labels = [
            (os.path.join(image_dir, img_name), os.path.join(label_dir, lab_name))
            for img_name, lab_name in zip(self.images, self.labels)
        ]

    @staticmethod
    def _encode_mask(mask):
        """Turn a uint8 grayscale mask ``[H, W]`` into a float32 one-hot ``[2, H, W]``.

        The division followed by the uint8 cast truncates, so only pixels
        that are exactly 255 count as background; every other value is
        treated as target.
        """
        binary = (mask / 255).astype('uint8')   # 255 -> 1, anything else -> 0
        # Channel 0: background (white), channel 1: target (non-white).
        return np.stack([binary == 1, binary == 0]).astype('float32')

    def __getitem__(self, item):
        """Load, resize to 224x224 and encode the ``item``-th image / mask pair.

        Raises:
            FileNotFoundError: if the image or the mask cannot be read.
        """
        img_path, lab_path = self.images_and_labels[item]

        # NOTE: cv2.imread yields channels in BGR order; no conversion is done.
        img = cv2.imread(img_path)
        if img is None:     # cv2.imread signals failure by returning None
            raise FileNotFoundError('Cannot read image: ' + img_path)
        img = cv2.resize(img, (224, 224))

        lab = cv2.imread(lab_path, cv2.IMREAD_GRAYSCALE)
        if lab is None:
            raise FileNotFoundError('Cannot read label: ' + lab_path)
        # The default interpolation blends values along edges; blended
        # pixels fall below 255 and therefore end up in the target class.
        lab = cv2.resize(lab, (224, 224))
        lab = self._encode_mask(lab)    # [224, 224] => [2, 224, 224]

        if self.transform is not None:
            img = self.transform(img)
        return img, lab

    def __len__(self):
        """Return the number of image / mask pairs."""
        return len(self.images)


if __name__ == '__main__':
    # Demo: shrink one mask to 16x16, one-hot encode it the same way the
    # dataset does, and print the two resulting planes.
    mask = cv2.resize(cv2.imread('data/train/last_msk/150.jpg', 0), (16, 16))
    binary = (mask / 255).astype('uint8')    # 255 -> 1, everything else -> 0
    one_hot = np.eye(2)[binary]              # [16, 16] => [16, 16, 2]
    flipped = np.abs(one_hot - 1)            # swap the 0s and 1s
    print(flipped.shape)
    print(flipped.transpose(2, 0, 1))

直接运行load_img.py可查看编码后的一张标签图像矩阵。
150.jpg:
在这里插入图片描述

(16, 16, 2)
[[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1.]
  [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
  [1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
  [1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
  [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
  [0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0.]
  [0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]
  [0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]

三、训练

train.py:

import os
import model
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from load_img import MyDataset
from torchvision import transforms
from torch.utils.data import DataLoader


# Hyper-parameters and data location.
batchsize = 8
epochs = 50
train_data_path = 'data/train'

# ToTensor converts the HxWxC uint8 image to a CxHxW float tensor in [0, 1];
# Normalize then standardises each channel (the mean/std values are the ones
# commonly quoted for ImageNet).
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
bag = MyDataset(train_data_path, transform)
dataloader = DataLoader(bag, batch_size=batchsize, shuffle=True)


# Use the GPU when one is present instead of failing on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = model.Net().to(device)
# BCEWithLogitsLoss fuses the sigmoid with the binary cross-entropy, which is
# numerically more stable than torch.sigmoid followed by BCELoss; the network
# therefore feeds its raw logits straight into the criterion.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=0.7)

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs('checkpoints', exist_ok=True)

for epoch in range(1, epochs + 1):
    for batch_idx, (img, lab) in enumerate(dataloader):
        img, lab = img.to(device), lab.to(device)
        loss = criterion(net(img), lab)

        # Progress report every 20 steps.
        if batch_idx % 20 == 0:
            print('Epoch:[{}/{}]\tStep:[{}/{}]\tLoss:{:.6f}'.format(
                epoch, epochs, (batch_idx+1)*len(img), len(dataloader.dataset), loss.item()
            ))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Keep a snapshot every 10 epochs.  The whole module is pickled (not just
    # its state_dict) because the test script loads it with a bare torch.load.
    if epoch % 10 == 0:
        torch.save(net, 'checkpoints/model_epoch_{}.pth'.format(epoch))
        print('checkpoints/model_epoch_{}.pth saved!'.format(epoch))
Epoch:[1/50]	Step:[8/499]	Loss:0.702611
Epoch:[1/50]	Step:[168/499]	Loss:0.697093
Epoch:[1/50]	Step:[328/499]	Loss:0.686626
Epoch:[1/50]	Step:[488/499]	Loss:0.676049
Epoch:[2/50]	Step:[8/499]	Loss:0.667989
Epoch:[2/50]	Step:[168/499]	Loss:0.664439
Epoch:[2/50]	Step:[328/499]	Loss:0.638619
Epoch:[2/50]	Step:[488/499]	Loss:0.636599
Epoch:[3/50]	Step:[8/499]	Loss:0.616667

四、测试图片

import torch
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import os


class TestDataset(Dataset):
    """Unlabelled images read from a single directory, resized to 224x224.

    Args:
        test_img_path: directory whose entries are the images to segment.
        transform: optional callable applied to each resized image.
    """

    def __init__(self, test_img_path, transform=None):
        # os.listdir returns names in arbitrary order; sorting makes the
        # sample index (used by the caller to number its outputs) reproducible.
        self.test_img = sorted(os.listdir(test_img_path))
        self.transform = transform
        # Full path of every image, in the same order as self.test_img.
        self.images = [os.path.join(test_img_path, name) for name in self.test_img]

    def __getitem__(self, item):
        """Load the ``item``-th image, resize it and apply the transform.

        Raises:
            FileNotFoundError: if the file cannot be read as an image.
        """
        img_path = self.images[item]
        img = cv2.imread(img_path)
        if img is None:     # cv2.imread signals failure by returning None
            raise FileNotFoundError('Cannot read image: ' + img_path)
        img = cv2.resize(img, (224, 224))
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        """Return the number of images found in the directory."""
        return len(self.test_img)


test_img_path = 'data/test/last'
checkpoint_path = 'checkpoints/model_epoch_50.pth'
save_dir = 'result'
os.makedirs(save_dir, exist_ok=True)

# Same preprocessing as used for training.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
bag = TestDataset(test_img_path, transform)
dataloader = DataLoader(bag, batch_size=1, shuffle=False)

# Use the GPU when one is present instead of failing on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The checkpoint holds a whole pickled module, so torch.load unpickles
# arbitrary objects - only load files you trust.
# NOTE(review): PyTorch releases whose torch.load defaults to
# weights_only=True refuse to unpickle a whole module; pass
# weights_only=False on those versions.
net = torch.load(checkpoint_path, map_location=device)
net = net.to(device)
# Inference mode: BatchNorm must use its running statistics.  Without this
# each single-image batch would be normalised with its own statistics and
# would also overwrite the running averages.
net.eval()

with torch.no_grad():   # no gradients needed, saves memory and time
    for idx, img in enumerate(dataloader):
        img = img.to(device)
        output = torch.sigmoid(net(img))

        # Index of the smaller score per pixel: 1 where channel 0 wins,
        # 0 where channel 1 wins.
        output_np = np.argmin(output.cpu().numpy(), axis=1)

        # Drop the batch axis, scale to {0, 255} and store as an 8-bit
        # image (argmin yields a wide integer type).
        img_arr = (np.squeeze(output_np) * 255).astype(np.uint8)
        out_path = '%s/%03d.png' % (save_dir, idx)
        cv2.imwrite(out_path, img_arr)
        print(out_path)

全部代码:github

  • 11
    点赞
  • 74
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值