PyTorch:使用 PyTorch 实现 LinkNet 进行语义分割

一、代码中的数据集可以通过以下链接获取

百度网盘提取码:f1j7

二、代码运行环境

PyTorch(GPU 版)==1.10.1
Python==3.8
此外还需要:torchvision(代码使用了 draw_segmentation_masks;与 PyTorch 1.10.1 配套的版本为 0.11.2)、tqdm、matplotlib

三、数据集处理代码如下所示(保存为 data_loader.py,训练和预测脚本会通过 from data_loader import load_data 导入)

import os
import torch
from torch.utils import data
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks


class MaskDataset(data.Dataset):
    """Dataset of (RGB image, binary mask) pairs for two-class segmentation.

    Args:
        image_paths: paths of the input images.
        mask_paths: paths of the mask images, aligned index-by-index with
            ``image_paths``.
        transform: callable applied to the image, and also to the mask when
            ``mask_transform`` is not given. It must return a tensor.
        mask_transform: optional callable applied to the mask only. Masks
            should be resized with nearest-neighbour interpolation: a bilinear
            resize blends edge pixels, and the ``> 0`` threshold in
            ``__getitem__`` then turns every blended pixel into foreground.
            Defaults to ``transform``, which keeps the original behaviour.
    """

    def __init__(self, image_paths, mask_paths, transform, mask_transform=None):
        super(MaskDataset, self).__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.mask_transform = transform if mask_transform is None else mask_transform

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for sample ``index``.

        The label is the mask after ``mask_transform``. Every non-zero value
        is set to 1, size-1 dimensions are squeezed away, and the result is
        cast to ``torch.LongTensor``, the form ``nn.CrossEntropyLoss`` takes
        as class-index targets.
        """
        # Use `with` so the underlying file handles are closed explicitly.
        with Image.open(self.image_paths[index]) as pil_img:
            img_tensor = self.transform(pil_img.convert('RGB'))

        # The transform must run inside the `with`, because PIL loads pixel
        # data lazily from the still-open file.
        with Image.open(self.mask_paths[index]) as pil_label:
            label_tensor = self.mask_transform(pil_label)

        # Any non-zero mask value is foreground (class 1); zero is background.
        label_tensor[label_tensor > 0] = 1
        label_tensor = torch.squeeze(input=label_tensor).type(torch.LongTensor)

        return img_tensor, label_tensor

    def __len__(self):
        return len(self.mask_paths)


def _pair_paths(directory):
    """Return aligned ``(image_paths, mask_paths)`` lists for one split directory."""
    # Sort so the pairing order (and hence the unshuffled test order) is
    # deterministic; os.listdir order is arbitrary.
    file_names = sorted(os.listdir(directory))
    # Set gives O(1) membership checks instead of scanning the list per mask.
    available = set(file_names)

    image_paths, mask_paths = [], []
    for name in file_names:
        # The image for a mask is "<prefix>.png", where <prefix> is the mask
        # name up to its first underscore.
        image_name = name.split('_')[0] + '.png'
        if 'matte' in name and image_name in available:
            image_paths.append(os.path.join(directory, image_name))
            mask_paths.append(os.path.join(directory, name))
    return image_paths, mask_paths


def load_data(dataset_path=r'/Users/leeakita/Desktop/hk', batch_size=8, image_size=(256, 256)):
    """Build the training and test DataLoaders.

    ``dataset_path`` must contain ``training`` and ``testing`` subdirectories.
    In each one, masks are the files whose name contains ``matte``. A mask's
    image is ``<prefix>.png``, where ``<prefix>`` is the mask name up to its
    first ``_``. Masks without a matching image are skipped.

    Args:
        dataset_path: dataset root directory.
        batch_size: samples per batch for both loaders.
        image_size: ``(height, width)`` that images and masks are resized to.

    Returns:
        ``(train_dl, test_dl)``. Only the training loader shuffles.

    NOTE: the same transform, a default (bilinear) ``Resize`` plus
    ``ToTensor``, is applied to both images and masks.
    """
    train_image_paths, train_label_paths = _pair_paths(os.path.join(dataset_path, 'training'))
    test_image_paths, test_label_paths = _pair_paths(os.path.join(dataset_path, 'testing'))

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor()
    ])

    train_ds = MaskDataset(image_paths=train_image_paths, mask_paths=train_label_paths, transform=transform)
    test_ds = MaskDataset(image_paths=test_image_paths, mask_paths=test_label_paths, transform=transform)

    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)

    return train_dl, test_dl


if __name__ == '__main__':
    # Quick visual check: overlay one training mask on its image.
    train_my, test_my = load_data()
    images, labels = next(iter(train_my))

    # Pick a sample from the first (shuffled) batch. Clamp the index so a
    # batch smaller than 6 does not raise IndexError.
    sample = min(5, images.shape[0] - 1)
    image = images[sample]
    # draw_segmentation_masks takes masks shaped (num_masks, H, W).
    mask = torch.unsqueeze(input=labels[sample], dim=0)

    result = draw_segmentation_masks(image=torch.as_tensor(data=image * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=mask, dtype=torch.bool),
                                     alpha=0.6, colors=['red'])
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()

四、模型的构建代码如下所示(保存为 model_loader.py,训练和预测脚本会通过 from model_loader import LinkNet 导入)

from torch import nn
import torch


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU.

    The three layers live in ``self.conv_bn_relu`` (an ``nn.Sequential``).
    The state_dict keys are therefore ``conv_bn_relu.0.*`` (conv) and
    ``conv_bn_relu.1.*`` (batch norm), and saved checkpoints depend on
    those names.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv_bn_relu = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv / batch-norm / ReLU stack."""
        out = self.conv_bn_relu(x)
        return out


class DecodeConvBlock(nn.Module):
    """Transposed convolution, optionally followed by BatchNorm2d and ReLU.

    With the default arguments (kernel 3, stride 2, padding 1, output
    padding 1) the spatial size doubles: H -> 2H.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, out_padding=1):
        super().__init__()
        self.de_conv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=out_padding,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x, is_act=True):
        """Upsample ``x``.

        With ``is_act=False`` the raw transposed-conv output is returned and
        the batch norm is skipped as well. Otherwise the output is passed
        through BatchNorm2d and ReLU.
        """
        upsampled = self.de_conv(x)
        if not is_act:
            return upsampled
        return torch.relu(self.bn(upsampled))


class EncodeBlock(nn.Module):
    """LinkNet encoder stage: two residual units; the first halves H and W.

    Unit 1: ``r1 = conv2(conv1(x)) + short_cut(x)``. Both ``conv1`` and the
    projection shortcut use stride 2, so the two branches have matching
    shapes.

    Unit 2: ``r1 + conv4(conv3(r1))``.
    """

    def __init__(self, in_channels, out_channels):
        super(EncodeBlock, self).__init__()
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)
        self.conv2 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv3 = ConvBlock(in_channels=out_channels, out_channels=out_channels)
        self.conv4 = ConvBlock(in_channels=out_channels, out_channels=out_channels)

        # Projection shortcut: matches the channel count and stride of conv1.
        self.short_cut = ConvBlock(in_channels=in_channels, out_channels=out_channels, stride=2)

    def forward(self, x):
        residual = self.conv2(self.conv1(x)) + self.short_cut(x)
        # The second unit's skip must add that unit's own input (`residual`).
        # The previous version added only the first unit's conv branch, which
        # dropped the shortcut path from the stage output.
        return residual + self.conv4(self.conv3(residual))


class DecodeBlock(nn.Module):
    """LinkNet decoder stage.

    The stage applies, in order:
    - a 1x1 conv reducing the channels to ``in_channels // 4``;
    - a transposed-conv upsampling (2x with ``DecodeConvBlock`` defaults);
    - a 1x1 conv projecting to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        mid_channels = in_channels // 4
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=mid_channels, kernel_size=1, padding=0)
        self.de_conv = DecodeConvBlock(in_channels=mid_channels, out_channels=mid_channels)
        self.conv3 = ConvBlock(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Reduce, upsample and project ``x``."""
        return self.conv3(self.de_conv(self.conv1(x)))


class LinkNet(nn.Module):
    """LinkNet-style encoder/decoder producing 2-class scores of shape (N, 2, H, W).

    The encoder halves the resolution six times: the init conv, the max-pool
    and four encoder stages. The decoder restores it, adding encoder features
    as skips. H and W should therefore be multiples of 64 so that the skip
    additions line up.

    Args:
        raw_logits: keyword-only. When True (the default), the final
            transposed conv output is returned unchanged, as
            ``nn.CrossEntropyLoss`` expects. The earlier version passed it
            through BatchNorm + ReLU, which made negative class scores
            impossible. Pass False to reproduce that behaviour, for example
            when loading a checkpoint trained with it.
    """

    def __init__(self, *, raw_logits=True):
        super(LinkNet, self).__init__()
        self.raw_logits = raw_logits

        self.init_conv = ConvBlock(in_channels=3, out_channels=64, stride=2, kernel_size=7, padding=3)
        self.init_maxpool = nn.MaxPool2d(kernel_size=(2, 2))

        self.encode_1 = EncodeBlock(in_channels=64, out_channels=64)
        self.encode_2 = EncodeBlock(in_channels=64, out_channels=128)
        self.encode_3 = EncodeBlock(in_channels=128, out_channels=256)
        self.encode_4 = EncodeBlock(in_channels=256, out_channels=512)

        self.decode_4 = DecodeBlock(in_channels=512, out_channels=256)
        self.decode_3 = DecodeBlock(in_channels=256, out_channels=128)
        self.decode_2 = DecodeBlock(in_channels=128, out_channels=64)
        self.decode_1 = DecodeBlock(in_channels=64, out_channels=64)

        self.deconv_out1 = DecodeConvBlock(in_channels=64, out_channels=32)
        self.conv_out = ConvBlock(in_channels=32, out_channels=32)
        # Final classifier: 2x2, stride-2 transposed conv to 2 class channels.
        self.deconv_out2 = DecodeConvBlock(in_channels=32, out_channels=2, kernel_size=2, padding=0, out_padding=0)

    def forward(self, x):
        x = self.init_conv(x)
        x = self.init_maxpool(x)

        e1 = self.encode_1(x)
        e2 = self.encode_2(e1)
        e3 = self.encode_3(e2)
        e4 = self.encode_4(e3)

        # Decoder with additive skip connections from the matching encoder stage.
        d4 = self.decode_4(e4)
        d3 = self.decode_3(d4 + e3)
        d2 = self.decode_2(d3 + e2)
        d1 = self.decode_1(d2 + e1)

        f1 = self.deconv_out1(d1)
        f2 = self.conv_out(f1)
        # Skip BN + ReLU on the classifier so the logits are unconstrained.
        f3 = self.deconv_out2(f2, is_act=not self.raw_logits)
        return f3

五、模型的训练代码如下所示

import torch
from data_loader import load_data
from model_loader import LinkNet
from torch import nn
from torch import optim
import tqdm
import os

# Environment: use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the data.
train_dl, test_dl = load_data()

# Build the model.
model = LinkNet()
model = model.to(device=device)

# Training configuration.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.7)
EPOCHS = 100


def _batch_iou(labels, logits):
    """Return the foreground IoU of one batch as a Python float.

    ``labels`` holds class indices (non-zero = foreground). ``logits`` holds
    per-class scores along dim 1, and the predicted class is their argmax.
    When both prediction and label are empty the IoU is defined as 1.0,
    instead of 0/0 = NaN, which would poison the epoch mean.
    """
    predicted = torch.argmax(input=logits, dim=1)
    intersection = torch.logical_and(input=labels, other=predicted).sum()
    union = torch.logical_or(input=labels, other=predicted).sum()
    if union.item() == 0:
        return 1.0
    return torch.true_divide(intersection, union).item()


# Training loop.
for epoch in range(EPOCHS):
    # train(): BatchNorm normalises with batch statistics and updates its running stats.
    model.train()
    train_loss_total, train_iou_total, train_steps = 0.0, 0.0, 0
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:3d}'.format(epoch))
    for train_images, train_labels in train_tqdm:
        train_images, train_labels = train_images.to(device), train_labels.to(device)
        pred = model(train_images)
        loss = loss_fn(pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # Running sums instead of re-concatenating a tensor every step.
            train_steps += 1
            train_loss_total += loss.item()
            train_iou_total += _batch_iou(train_labels, pred)
        train_tqdm.set_postfix({
            'train loss': train_loss_total / train_steps,
            'train iou': train_iou_total / train_steps
        })
    train_tqdm.close()

    lr_scheduler.step()

    # eval(): BatchNorm uses its running statistics, and test batches do not update them.
    model.eval()
    test_loss_total, test_iou_total, test_steps = 0.0, 0.0, 0
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:3d}'.format(epoch))
        for test_images, test_labels in test_tqdm:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            test_pred = model(test_images)
            # CrossEntropyLoss applies log-softmax itself, so pass the raw
            # logits. Softmax-ing first would make this loss incomparable
            # with the training loss.
            test_loss = loss_fn(test_pred, test_labels)

            test_steps += 1
            test_loss_total += test_loss.item()
            test_iou_total += _batch_iou(test_labels, test_pred)
            test_tqdm.set_postfix({
                'test loss': test_loss_total / test_steps,
                'test iou': test_iou_total / test_steps
            })
        test_tqdm.close()

# Save the trained weights.
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))

六、模型的预测代码如下所示

import torch
import os
import matplotlib.pyplot as plt
from torchvision.utils import draw_segmentation_masks
from data_loader import load_data
from model_loader import LinkNet

# Load the data.
train_dl, test_dl = load_data()

# Load the trained weights on the CPU. Switch to eval() so BatchNorm uses its
# running statistics rather than statistics of the current batch.
model = LinkNet()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()

# Predict the first test batch and overlay the mask of one sample.
images, labels = next(iter(test_dl))
# Clamp so a batch with fewer than 3 samples does not raise IndexError.
index = min(2, images.shape[0] - 1)
with torch.no_grad():
    pred = model(images)
    pred = torch.argmax(input=pred, dim=1)
    result = draw_segmentation_masks(image=torch.as_tensor(data=images[index] * 255, dtype=torch.uint8),
                                     masks=torch.as_tensor(data=pred[index], dtype=torch.bool),
                                     alpha=0.8, colors=['red'])
    plt.figure(figsize=(8, 8), dpi=500)
    plt.axis('off')
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.savefig('result.png')
    plt.show()

七、代码的运行结果如下所示

(此处原为运行结果图片,即预测脚本保存的 result.png;图片未随文本保留。)

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

水哥很水

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值