利用小型数据集m2nist进行语义分割——(三)代码编写及训练与预测

利用小型数据集m2nist进行语义分割——(三)代码编写及训练与预测

微信公众号:幼儿园的学霸

目录

前言

接下来按照上一篇的神经网络框架,进行具体的代码编写。属于对构思的具体实现,相对还是比较容易的。 该文贴出了部分代码,完整代码地址:https://github.com/leonardohaig/m2nist-segmentation

代码编写

代码不多,非常简单。甚至一个.py文件都可以完成,为了清晰,我将其进行了划分。

编写完毕后代码文件夹内容如下所示:
代码结构

数据加载模块

数据加载模块加载m2nist数据集,并对图片和标签进行处理:
1)图片和标签的尺寸调整:在宽度方向左右各填充6个像素,将尺寸从(64,84)填充到(64,96)。不建议采用resize操作,感兴趣的可以对比一下两者的区别;
2)将输入图像归一化到0~1区间,并增加通道维度:加载后的图像shape为[B,H,W],需要将其扩展为[B,C,H,W](此处C=1);
3)将numpy类型的数据转换为tensor格式。

具体到代码编写过程,需要用到PyTorch中的Dataset和DataLoader类。由于数据集非常小,因此一次性全部读入内存,代码如下:

在向训练模块提供数据时,DataLoader的工作进程数(num_workers)设置为电脑CPU核数的一半。

#!/usr/bin/env python3
# coding=utf-8

# ============================#
# Program:m2nistDataSet.py
#       数据加载模块
# Date:20-4-16
# Author:liheng
# Version:V1.0
# ============================#

import os
import sys
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import utils_torch
from multiprocessing import cpu_count

__all__ = ['m2nistDataLoader']


class m2nistDatase(Dataset):
    """In-memory m2nist segmentation dataset.

    Both ``.npy`` files are loaded eagerly when the dataset is created.
    Every image and mask is padded along its last (width) axis with 6
    columns on each side. Images are padded with 0, scaled to [0, 1] as
    float32 and given a channel axis. Masks are padded with 10, the
    background class, and stored as uint8.

    :param imgs_pth: path to the images ``.npy`` file, indexed as [N, H, W]
    :param masks_pth: path to the masks ``.npy`` file, indexed as [N, H, W]
    """

    def __init__(self, imgs_pth, masks_pth):
        for pth in (imgs_pth, masks_pth):
            assert os.path.isfile(pth)

        # Widen only the width axis: 6 columns on the left and 6 on the right.
        pad_width = ((0, 0), (0, 0), (6, 6))

        padded_imgs = np.pad(np.load(imgs_pth), pad_width, 'constant', constant_values=0)
        # 10 is the background label, so padded mask pixels count as background.
        padded_masks = np.pad(np.load(masks_pth), pad_width, 'constant', constant_values=10)

        # [N, H, W] -> [N, 1, H, W], values scaled into [0, 1].
        self.imgs = (padded_imgs.astype(np.float32) / 255)[:, np.newaxis]
        self.masks = padded_masks.astype(np.uint8)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # torch.tensor copies, so samples never alias the cached arrays.
        return torch.tensor(self.imgs[index]), torch.tensor(self.masks[index])


def m2nistDataLoader(cfg_pth, dataset_type='train'):
    """Build a DataLoader over the m2nist training or validation split.

    Image/mask paths and the batch size are read from the config returned by
    ``utils_torch.get_config``, using the ``Train.*`` or ``Val.*`` keys.

    :param cfg_pth: path to the config file
    :param dataset_type: ``'train'`` or ``'val'`` (``'validation'`` is accepted
        as an alias)
    :return: a ``DataLoader`` yielding ``(img, mask)`` batches; only the
        training split is shuffled
    :raises FileNotFoundError: if ``cfg_pth`` is not an existing file
    :raises ValueError: if ``dataset_type`` is not a recognised split name
    """
    # Explicit check instead of assert so it still runs under `python -O`.
    if not os.path.isfile(cfg_pth):
        raise FileNotFoundError('config file does not exist ! ' + str(cfg_pth))

    # Previously any unknown value (e.g. a typo of 'train') silently loaded
    # the validation split.
    if dataset_type == 'train':
        prefix = 'Train'
    elif dataset_type in ('val', 'validation'):
        prefix = 'Val'
    else:
        raise ValueError("dataset_type must be 'train' or 'val', got %r" % (dataset_type,))

    config = utils_torch.get_config(cfg_pth)

    dataset = m2nistDatase(config[prefix + '.images_pth'], config[prefix + '.masks_pth'])
    return DataLoader(dataset=dataset,
                      batch_size=config[prefix + '.batch_size'],
                      # shuffling only helps optimisation; keep validation order stable
                      shuffle=(prefix == 'Train'),
                      # half the CPU cores as worker processes (0 = load in main process)
                      num_workers=cpu_count() // 2)


if __name__ == '__main__':
    # Run from this script's directory so the relative './config.yaml' resolves.
    script_dir = os.path.split(os.path.abspath(__file__))[0]
    os.chdir(script_dir)

    # down_data (provides show_img_mask) is expected one directory up.
    sys.path.append('..')
    import down_data

    # Debug switch: True builds the loader by hand from ../m2nist/*.npy,
    # False goes through the config-driven m2nistDataLoader factory.
    BUILD_LOADER_MANUALLY = False

    if BUILD_LOADER_MANUALLY:
        data_rootdir = os.path.join(os.path.split(os.path.realpath(__file__))[0], '../', 'm2nist')
        dataloader = DataLoader(
            dataset=m2nistDatase(os.path.join(data_rootdir, 'train_imgs.npy'),
                                 os.path.join(data_rootdir, 'train_masks.npy')),
            batch_size=6,
            shuffle=True,
            num_workers=cpu_count() // 2)
    else:
        dataloader = m2nistDataLoader('./config.yaml')

    # Display the first sample of every batch.
    for imgs, masks in dataloader:
        down_data.show_img_mask(np.squeeze(imgs[0].numpy()), masks[0].numpy())

网络实现模块

网络实现模块按照上一篇文章中的结构进行网络的复现。同时我将损失函数也放在了该模块中。
代码如下:

#!/usr/bin/env python3
#coding=utf-8

#============================#
#Program:Model.py
#       
#Date:20-4-16
#Author:liheng
#Version:V1.0
#============================#

import layers
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(torch.nn.Module):
    """Encoder-decoder segmentation network for m2nist.

    Four downsampling encoder stages are followed by four upsampling decoder
    stages. The first three decoder outputs are concatenated channel-wise
    with the encoder feature map of the same resolution. A final conv maps
    16 channels to 11 class scores (digits 0-9 plus background 10).
    Shapes for a [B, 1, 64, 96] input are annotated in ``forward``.
    """

    def __init__(self):
        super().__init__()
        # NOTE: attribute names and registration order define state_dict keys,
        # parameter order and init RNG order; keep them stable.

        # encoder: channels 1 -> 16 -> 32 -> 64 -> 96
        enc_channels = (1, 16, 32, 64, 96)
        for idx, (c_in, c_out) in enumerate(zip(enc_channels[:-1], enc_channels[1:]), start=1):
            setattr(self, 'encoder{}'.format(idx), self.EncoderBlock(c_in, c_out))

        # decoder: input widths include the concatenated skip connections
        for name, c_in, c_out in (('decode4', 96, 64),
                                  ('decode3', 128, 32),
                                  ('decode2', 64, 16),
                                  ('decode1', 32, 16)):
            setattr(self, name, self.DecodeBlock(c_in, c_out))

        # 16 feature channels -> 11 class logits
        self.res_conv = layers.conv2d(16, 11, 3, 1)

    def EncoderBlock(self, in_channels, out_channels, t=6):
        """Build one encoder stage.

        A ``layers.sepconv2d`` (kernel 3, stride argument 2) is followed by a
        ``layers.InvertedResidual`` with expansion ``t`` and stride 1.
        """
        downsample = layers.sepconv2d(in_channels, out_channels, 3, 2, False)
        refine = layers.InvertedResidual(out_channels, out_channels, t=t, s=1)
        return torch.nn.Sequential(downsample, refine)

    def DecodeBlock(self, in_channels, out_channels, kernel_size=3, bias=True):
        """Build one decoder stage.

        A 1x1 conv squeezes to ``in_channels // 4`` channels. A stride-2
        transposed conv (padding 1, output_padding 1) follows; with the default
        ``kernel_size=3`` it doubles H and W. A final 1x1 conv projects to
        ``out_channels``. Every conv is followed by BatchNorm2d and ReLU6.

        :param in_channels: channels of the incoming feature map
        :param out_channels: channels produced by the stage
        :param kernel_size: kernel size of the transposed convolution
        :param bias: whether the convolutions carry a bias term
        :return: ``torch.nn.Sequential`` of 9 layers
        """
        mid = in_channels // 4
        stage = [
            # conv 1x1 (squeeze)
            nn.Conv2d(in_channels, mid, 1, bias=bias),
            nn.BatchNorm2d(mid),
            nn.ReLU6(),
            # deconv 3x3 (upsample)
            nn.ConvTranspose2d(mid, mid, kernel_size, stride=2, padding=1,
                               output_padding=1, bias=bias),
            nn.BatchNorm2d(mid),
            nn.ReLU6(),
            # conv 1x1 (project)
            nn.Conv2d(mid, out_channels, 1, bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU6(),
        ]
        return torch.nn.Sequential(*stage)

    def forward(self, x):
        # encode stage
        e1 = self.encoder1(x)   # [B,16,32,48]
        e2 = self.encoder2(e1)  # [B,32,16,24]
        e3 = self.encoder3(e2)  # [B,64,8,12]
        e4 = self.encoder4(e3)  # [B,96,4,6]

        # decode stage: upsample, then append the matching encoder map
        # -> [B,64+64,8,12], [B,32+32,16,24], [B,16+16,32,48]
        feat = e4
        for decoder, skip in ((self.decode4, e3), (self.decode3, e2), (self.decode2, e1)):
            feat = torch.cat((decoder(feat), skip), dim=1)

        # [B,16,64,96] -> [B,11,64,96]
        return self.res_conv(self.decode1(feat))


class CrossEntropyLoss2d(nn.Module):
    """Per-pixel cross-entropy loss for dense (2D) predictions.

    Wraps ``nn.CrossEntropyLoss``, which applies log-softmax internally, so
    ``outputs`` must be raw logits (no softmax/LogSoftmax layer is needed).
    """

    def __init__(self, weight=None, ignore_label=255):
        """
        :param weight: optional 1D tensor of per-class weights (one entry per
            class), e.g. to counter class imbalance
        :param ignore_label: target value whose pixels contribute neither to
            the loss nor to the averaging denominator
        """
        super().__init__()
        # ignore_label used to be accepted but never forwarded, so pixels carrying
        # it raised an out-of-range target error instead of being skipped.
        self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)

    def forward(self, outputs, targets):
        """
        :param outputs: logits, [B, C, H, W]
        :param targets: integer (long) class indices, [B, H, W]
        :return: scalar loss, the weighted mean over non-ignored pixels
        """
        return self.loss(outputs, targets)


if __name__ == '__main__':
    from torchstat import stat

    # initial model
    model = Model()

    # dummy batch in [B,C,H,W] layout (not used by torchstat below)
    input_data = torch.ones([5, 1, 64, 96], dtype=torch.float32)

    # report parameter count and compute cost for a single [1,64,96] input
    stat(model, (1, 64, 96))

    # Set to True to also dump the model/optimizer state_dict layouts below.
    DUMP_STATE_DICTS = False
    if not DUMP_STATE_DICTS:
        exit(0)

    # initialize the optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    print("model's state_dict:")
    for key, tensor in model.state_dict().items():
        print(key, '\t', tensor.size())

    print("\noptimizer's state_dict")
    for key, value in optimizer.state_dict().items():
        print(key, '\t', value)

运行该模块,可以查看网络模型的参数量和运算量,如下图所示:
参数量大小

训练模块

训练模块实现模型的训练及保存,此外,添加了summary,便于利用tensorboard对训练过程进行观察。
代码如下:
训练时的参数(如学习率、batch size等)写在了配置文件中,运行代码时需要通过--cfg_pth指定配置文件路径。

#!/usr/bin/env python3
# coding=utf-8

# ============================#
# Program:train.py
#       训练模型
# Date:20-4-16
# Author:liheng
# Version:V1.0
# ============================#

import Model
import m2nistDataSet
import utils_torch
import argparse
import numpy as np
import os
import shutil
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter


class Train(object):
    """Trains ``Model.Model`` on m2nist and saves a checkpoint after every epoch.

    Hyper-parameters and paths come from the config returned by
    ``utils_torch.get_config``. Per-step scalars and images are written to
    TensorBoard through ``tensorboardX.SummaryWriter``.
    """

    def __init__(self, config_file: str):
        # read configuration
        self.config = utils_torch.get_config(config_file)

        # data loaders
        self.train_dataset = m2nistDataSet.m2nistDataLoader(config_file, 'train')
        self.val_dataset = m2nistDataSet.m2nistDataLoader(config_file, 'val')

        # output folders; the log dir is wiped so every run starts a fresh log
        os.makedirs(self.config['Train.model_save_dir'], exist_ok=True)
        if os.path.exists(self.config['Train.log_dir']):
            shutil.rmtree(self.config['Train.log_dir'])
        os.makedirs(self.config['Train.log_dir'])

        # model and optimisation
        self.device = torch.device('cuda'
                                   if (torch.cuda.is_available() and self.config['USE_CUDA'])
                                   else 'cpu')
        self.model = Model.Model().to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['Train.lr_init'])
        # stepped once per batch, so lr is multiplied by 0.99 every 100 iterations
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.99)

        # classes 0-9 keep weight 1.0; label 10 (background) is down-weighted to 0.2
        class_weights = torch.tensor(10 * [1.0] + [0.2], dtype=torch.float32)
        # Built once (not per batch). The class weights are a registered buffer of
        # the wrapped nn.CrossEntropyLoss, so .to() moves them next to the logits.
        self.loss_func = Model.CrossEntropyLoss2d(class_weights).to(self.device)
        self.global_setp = 0

        # summary
        self.summary_writer = SummaryWriter(self.config['Train.log_dir'])
        # the tracing input must be on the model's device, otherwise add_graph fails on CUDA
        self.summary_writer.add_graph(self.model,
                                      (torch.rand([1, 1, 64, 96], device=self.device),))

    def _restore_last_checkpoint(self):
        """Best effort: load the newest file in ``Train.model_save_dir`` into the model."""
        try:
            last_model = utils_torch.find_new_file(self.config['Train.model_save_dir'])
            self.model.load_state_dict(torch.load(last_model, map_location=self.device))
            print('[info] Restoring weights from last trained file ...')
        except Exception as e:  # any failure falls back to training from scratch
            print('[info] Can not find last trained file !!!')
            print('[info] Reason: %r' % (e,))
            print('[info] Now it starts to train model from scratch ...')

    def _log_train_step(self, loss_value, batch_x, batch_y, out):
        """Write lr, loss and input / ground-truth / predicted images for this step."""
        step = self.global_setp
        self.summary_writer.add_scalar('learning rate', self.optimizer.param_groups[0]['lr'], step)
        self.summary_writer.add_scalar('train loss', loss_value, step)
        self.summary_writer.add_images('train input images', batch_x, step)
        # tran_masks2images is fed numpy arrays, so tensors must be moved to the CPU first
        self.summary_writer.add_images('train gt images',
                                       utils_torch.tran_masks2images(batch_y.cpu().numpy()),
                                       step)
        # argmax over logits equals argmax over their softmax
        pred = torch.argmax(out.detach(), dim=1)
        self.summary_writer.add_images('train pred images',
                                       utils_torch.tran_masks2images(pred.cpu().numpy()),
                                       step)

    def _validate(self):
        """Return the mean loss over the validation loader (leaves the model in eval mode)."""
        self.model.eval()
        losses = []
        with torch.no_grad():
            for batch_x, batch_y in self.val_dataset:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                out = self.model(batch_x)
                losses.append(self.loss_func(out, batch_y.long()).item())
        return np.mean(losses)

    def train(self):
        """Run ``Train.max_epochs`` epochs of training.

        The newest checkpoint is restored first when possible. Each epoch:
        1. trains on the training split, logging every step;
        2. computes the validation loss;
        3. prints both averages;
        4. saves ``m2nist-seg_epoch{N}.pth`` to ``Train.model_save_dir``.
        The summary writer is closed at the end.
        """
        self._restore_last_checkpoint()

        for epoch in range(1, 1 + self.config['Train.max_epochs']):
            # _validate() switches to eval mode; without this, every epoch after
            # the first would train with BatchNorm frozen to running statistics.
            self.model.train()
            train_losses = []
            pbar = tqdm(self.train_dataset)
            for batch_x, batch_y in pbar:
                batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                out = self.model(batch_x)

                loss = self.loss_func(out, batch_y.long())
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                self.global_setp += 1

                loss_value = loss.item()
                train_losses.append(loss_value)
                pbar.set_description("Epoch:%d Step:%d loss:%.3f" % (epoch, self.global_setp, loss_value))

                self._log_train_step(loss_value, batch_x, batch_y, out)

            train_avg_loss = np.mean(train_losses)
            val_avg_loss = self._validate()
            print('epoch:%d, train loss:%.5f, val loss:%.5f ' % (epoch, train_avg_loss, val_avg_loss))

            save_name = os.path.join(self.config['Train.model_save_dir'],
                                     'm2nist-seg_epoch{:d}.pth'.format(epoch))
            torch.save(self.model.state_dict(), save_name)

        self.summary_writer.close()


def init_args(argv=None):
    """Parse the command-line arguments of the training script.

    :param argv: list of arguments to parse; ``None`` (default) parses
        ``sys.argv[1:]`` as before
    :return: ``argparse.Namespace`` with ``cfg_pth``, the config file path,
        which defaults to ``config.yaml`` in this script's directory
    """
    parser = argparse.ArgumentParser()

    # Resolve the default next to this file instead of a developer-specific
    # absolute path, so the script works from any checkout location.
    default_cfg = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml')
    parser.add_argument('--cfg_pth', type=str,
                        help='The config file path',
                        default=default_cfg)

    return parser.parse_args(argv)


if __name__ == '__main__':
    args = init_args()
    # Explicit check instead of assert: it still runs under `python -O`, and
    # the message now has a space between the path and the text.
    if not os.path.isfile(args.cfg_pth):
        raise SystemExit(args.cfg_pth + ' does not exist !')
    trainer = Train(args.cfg_pth)
    trainer.train()

预测模块

预测模块没啥可说的,加载模型,然后预测、将结果可视化即可。代码此处不贴啦。

训练与预测

训练

代码训练过程可视化如下:
scalars
images

预测

预测结果如下:
可以看到,第一幅图像的分割结果还是能够接受的,而第二幅图像的分割结果就不够精细,数字5和3的部分像素被错误地归为了其他数字。

预测结果1
预测结果2



下面的是我的公众号二维码图片,欢迎关注。
图注:幼儿园的学霸

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值