利用小型数据集m2nist进行语义分割——(三)代码编写及训练与预测
微信公众号:幼儿园的学霸
目录
前言
接下来按照上一篇的神经网络框架,进行具体的代码编写。属于对构思的具体实现,相对还是比较容易的。 该文贴出了部分代码,完整代码地址:https://github.com/leonardohaig/m2nist-segmentation
代码编写
代码不多,非常简单,甚至一个 .py 文件都可以完成。为了清晰,我将其划分成了几个模块。
编写完毕后代码文件夹内容如下所示:
数据加载模块
数据加载模块加载m2nist数据集,并对图片和标签进行处理:
1)图片和标签的尺寸调整:将尺寸从(64,84)填充(padding)到(64,96),不建议采用resize操作,感兴趣的可以看下区别;
2)输入图像归一化到0~1区间,以及通道的变换:加载后的图像其shape顺序为[B,H,W],需要将其变换为[B,C,H,W]的顺序;
3)将numpy类型的数据转换为tensor格式。
具体到代码编写过程,需要采用pytorch中的Dataset和DataLoader模块。由于数据集非常小,因此一次性全部读入内存,代码如下:
在向训练模块提供数据时,数据加载的工作进程数量(num_workers)取电脑CPU核心数的一半,即 cpu_count() // 2。
#!/usr/bin/env python3
# coding=utf-8
# ============================#
# Program:m2nistDataSet.py
# 数据加载模块
# Date:20-4-16
# Author:liheng
# Version:V1.0
# ============================#
import os
import sys
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import utils_torch
from multiprocessing import cpu_count
__all__ = ['m2nistDataLoader']
class m2nistDatase(Dataset):
    """In-memory dataset of M2NIST images and their segmentation masks.

    Both ``.npy`` files are read completely in the constructor. Every array is
    padded along its last axis with 6 columns on each side (84 -> 96 for the
    stock M2NIST data): images are padded with 0, masks with 10, because label
    10 denotes the background.
    """

    def __init__(self, imgs_pth, masks_pth):
        """Load, pad and normalise the image and mask arrays.

        :param imgs_pth: path of the ``.npy`` file holding the images
        :param masks_pth: path of the ``.npy`` file holding the label masks
        """
        for npy_file in (imgs_pth, masks_pth):
            assert os.path.isfile(npy_file)

        # Pad only the last (width) axis; batch and height stay untouched.
        width_padding = ((0, 0), (0, 0), (6, 6))
        padded_imgs = np.pad(np.load(imgs_pth), width_padding,
                             'constant', constant_values=0)
        padded_masks = np.pad(np.load(masks_pth), width_padding,
                              'constant', constant_values=10)

        # Scale pixel values to [0, 1] and insert the channel axis: [B,C,H,W].
        scaled = padded_imgs.astype(np.float32) / 255
        self.imgs = scaled[:, np.newaxis]
        self.masks = padded_masks.astype(np.uint8)

    def __getitem__(self, index):
        """Return the ``(image, mask)`` tensor pair stored at ``index``."""
        return torch.tensor(self.imgs[index]), torch.tensor(self.masks[index])

    def __len__(self):
        """Return the number of samples (size of the leading axis)."""
        return len(self.imgs)
def m2nistDataLoader(cfg_pth, dataset_type='train'):
    """Build a ``DataLoader`` over the M2NIST split described in a config file.

    :param cfg_pth: path of the configuration file, parsed with
        ``utils_torch.get_config``
    :param dataset_type: ``'train'`` selects the ``Train.*`` config entries;
        any other value (conventionally ``'val'``) selects the ``Val.*`` ones
    :return: ``DataLoader`` over an ``m2nistDatase`` built from the selected
        image/mask files, using the selected batch size
    """
    assert os.path.isfile(cfg_pth), 'config file does not exist !'
    config = utils_torch.get_config(cfg_pth)

    is_train = dataset_type == 'train'
    section = 'Train' if is_train else 'Val'
    imgs_pth = config[section + '.images_pth']
    masks_pth = config[section + '.masks_pth']
    batch_size = config[section + '.batch_size']

    dataset = m2nistDatase(imgs_pth, masks_pth)
    # Only the training split is shuffled: evaluation gains nothing from a
    # random order, and a fixed order keeps per-batch results reproducible.
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=is_train,
                            num_workers=cpu_count() // 2)
    return dataloader
if __name__ == '__main__':
    # Work from this file's directory so the relative paths below
    # ('..' and './config.yaml') are resolved against it.
    os.chdir(os.path.split(os.path.abspath(__file__))[0])
    sys.path.append('..')
    import down_data

    # Switch: True builds the loader straight from the .npy files,
    # False (the default) goes through the config file.
    load_from_npy_directly = False
    if load_from_npy_directly:
        here = os.path.split(os.path.realpath(__file__))[0]
        data_rootdir = os.path.join(here, '../', 'm2nist')
        imgs_pth = os.path.join(data_rootdir, 'train_imgs.npy')
        masks_pth = os.path.join(data_rootdir, 'train_masks.npy')
        dataset = m2nistDatase(imgs_pth, masks_pth)
        dataloader = DataLoader(dataset=dataset, batch_size=6,
                                shuffle=True, num_workers=cpu_count() // 2)
    else:
        dataloader = m2nistDataLoader('./config.yaml')

    # Display the first sample of every batch together with its mask.
    for images, masks in dataloader:
        first_img = np.squeeze(images[0].numpy())
        down_data.show_img_mask(first_img, masks[0].numpy())
网络实现模块
网络实现模块按照上一篇文章中的结构进行网络的复现。同时我将损失函数也放在了该模块中。
代码如下:
#!/usr/bin/env python3
#coding=utf-8
#============================#
#Program:Model.py
#
#Date:20-4-16
#Author:liheng
#Version:V1.0
#============================#
import layers
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(torch.nn.Module):
    """Encoder-decoder segmentation network with skip connections.

    Four encoder stages are followed by four decoder stages. The outputs of
    decoder stages 4, 3 and 2 are concatenated channel-wise with the encoder
    feature map of the same resolution before entering the next decoder
    stage. A final ``layers.conv2d`` produces 11 output channels.
    """

    def __init__(self):
        super().__init__()

        # encoder: 1 -> 16 -> 32 -> 64 -> 96 channels
        self.encoder1 = self.EncoderBlock(1, 16)
        self.encoder2 = self.EncoderBlock(16, 32)
        self.encoder3 = self.EncoderBlock(32, 64)
        self.encoder4 = self.EncoderBlock(64, 96)

        # decoder: stages 3..1 take the previous decoder output concatenated
        # with a skip feature map, hence the doubled input channel counts
        self.decode4 = self.DecodeBlock(96, 64)
        self.decode3 = self.DecodeBlock(128, 32)
        self.decode2 = self.DecodeBlock(64, 16)
        self.decode1 = self.DecodeBlock(32, 16)

        # output head: 16 -> 11 channels
        self.res_conv = layers.conv2d(16, 11, 3, 1)

    def EncoderBlock(self, in_channels, out_channels, t=6):
        """Return one encoder stage.

        The stage is ``layers.sepconv2d`` followed by
        ``layers.InvertedResidual``.

        :param in_channels: channels of the incoming feature map
        :param out_channels: channels produced by the stage
        :param t: forwarded as ``t`` to ``layers.InvertedResidual``
            (presumably its expansion factor -- confirm in ``layers``)
        :return: the stage as a ``Sequential``
        """
        downsample = layers.sepconv2d(in_channels, out_channels, 3, 2, False)
        residual = layers.InvertedResidual(out_channels, out_channels, t=t, s=1)
        return nn.Sequential(downsample, residual)

    def DecodeBlock(self, in_channels, out_channels, kernel_size=3, bias=True):
        """Return one decoder stage that doubles the spatial resolution.

        Layout: 1x1 conv reducing to ``in_channels // 4`` channels, a stride-2
        transposed conv, then a 1x1 conv to ``out_channels``; each followed by
        BatchNorm and ReLU6.

        :param in_channels: channels of the incoming feature map
        :param out_channels: channels produced by the stage
        :param kernel_size: kernel size of the transposed convolution
        :param bias: whether the three convolutions use a bias term
        :return: the stage as a ``Sequential``
        """
        squeezed = in_channels // 4
        stages = [
            # 1x1 conv: channel reduction
            nn.Conv2d(in_channels, squeezed, 1, bias=bias),
            nn.BatchNorm2d(squeezed),
            nn.ReLU6(),
            # transposed conv: x2 upsampling
            nn.ConvTranspose2d(squeezed, squeezed, kernel_size, stride=2,
                               padding=1, output_padding=1, bias=bias),
            nn.BatchNorm2d(squeezed),
            nn.ReLU6(),
            # 1x1 conv: projection to the output channel count
            nn.Conv2d(squeezed, out_channels, 1, bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU6(),
        ]
        return nn.Sequential(*stages)

    def forward(self, x):
        """Run the network; shape comments assume a [B,1,64,96] input."""
        # encode stage
        skip1 = self.encoder1(x)          # [B,16,32,48]
        skip2 = self.encoder2(skip1)      # [B,32,16,24]
        skip3 = self.encoder3(skip2)      # [B,64,8,12]
        bottleneck = self.encoder4(skip3)  # [B,96,4,6]

        # decode stage, concatenating the matching encoder features
        merged = torch.cat((self.decode4(bottleneck), skip3), dim=1)  # [B,64+64,8,12]
        merged = torch.cat((self.decode3(merged), skip2), dim=1)      # [B,32+32,16,24]
        merged = torch.cat((self.decode2(merged), skip1), dim=1)      # [B,16+16,32,48]
        upsampled = self.decode1(merged)                              # [B,16,64,96]

        # per-pixel class scores
        return self.res_conv(upsampled)                               # [B,11,64,96]
class CrossEntropyLoss2d(nn.Module):
    """Cross-entropy loss for dense (per-pixel) classification.

    Thin wrapper around ``nn.CrossEntropyLoss``. It consumes raw logits, so
    the network does not need a LogSoftmax layer at its end.
    """

    def __init__(self, weight=None, ignore_label=255):
        """
        :param weight: optional 1D tensor of per-class weights used to
            counter class imbalance
        :param ignore_label: target value whose pixels are excluded from the
            loss; ``None`` falls back to the ``nn.CrossEntropyLoss`` default
        """
        super().__init__()
        # Forward ignore_label to the criterion; previously the argument was
        # accepted but silently dropped, so "ignored" pixels still reached
        # the loss as out-of-range class indices.
        ignore_index = -100 if ignore_label is None else ignore_label
        self.loss = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)

    def forward(self, outputs, targets):
        """Return the loss of logits ``outputs`` against class-index ``targets``.

        :param outputs: raw class scores with the class axis at dim 1
        :param targets: integer class indices, one per pixel
        """
        return self.loss(outputs, targets)
if __name__ == '__main__':
    from torchstat import stat

    # Build the network and a dummy batch in [B,C,H,W] layout.
    model = Model()
    input_data = torch.ones([5, 1, 64, 96], dtype=torch.float32)

    # Report the model statistics for a single (1,64,96) input, then stop.
    stat(model, (1, 64, 96))
    exit(0)

    # NOTE: everything below is unreachable because of the exit(0) above;
    # it is kept as a reference for inspecting the state dicts.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    print("model's state_dict:")
    for name, tensor in model.state_dict().items():
        print(name, '\t', tensor.size())

    print("\noptimizer's state_dict")
    for name, value in optimizer.state_dict().items():
        print(name, '\t', value)
运行该模块,可以查看网络模型的参数量和运算量,如下图所示:
训练模块
训练模块实现模型的训练及保存,此外,添加了summary,便于利用tensorboard对训练过程进行观察。
代码如下:
训练时的参数(如学习率、batch size等)写在了配置文件中,运行代码时需要指定配置文件路径。
#!/usr/bin/env python3
# coding=utf-8
# ============================#
# Program:train.py
# 训练模型
# Date:20-4-16
# Author:liheng
# Version:V1.0
# ============================#
import Model
import m2nistDataSet
import utils_torch
import argparse
import numpy as np
import os
import shutil
from tqdm import tqdm
import torch
from tensorboardX import SummaryWriter
class Train(object):
    """Training driver for the M2NIST segmentation model.

    Builds the data loaders, model, optimizer, LR scheduler and TensorBoard
    writer from a config file; ``train()`` then runs the epoch loop, a
    validation pass after every epoch, and saves one checkpoint per epoch.
    """

    def __init__(self, config_file: str):
        """
        :param config_file: path of the configuration file, parsed with
            ``utils_torch.get_config``
        """
        # read the configuration
        self.config = utils_torch.get_config(config_file)

        # data loaders
        self.train_dataset = m2nistDataSet.m2nistDataLoader(config_file, 'train')
        self.val_dataset = m2nistDataSet.m2nistDataLoader(config_file, 'val')

        # output folders; the log dir is wiped so each run starts clean
        os.makedirs(self.config['Train.model_save_dir'], exist_ok=True)
        if os.path.exists(self.config['Train.log_dir']):
            shutil.rmtree(self.config['Train.log_dir'])
        os.makedirs(self.config['Train.log_dir'])

        # model, optimizer, scheduler
        self.device = torch.device('cuda'
                                   if (torch.cuda.is_available() and self.config['USE_CUDA'])
                                   else 'cpu')
        self.model = Model.Model().to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['Train.lr_init'])
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.99)
        # Kept for backward compatibility; train() uses a class-weighted
        # criterion instead of this unweighted one.
        self.loss_func = torch.nn.CrossEntropyLoss()
        self.global_setp = 0

        # summary
        self.summary_writer = SummaryWriter(self.config['Train.log_dir'])
        # The dummy input must live on the same device as the model.
        self.summary_writer.add_graph(
            self.model, (torch.rand([1, 1, 64, 96], device=self.device),))

    def train(self):
        """Run training, per-epoch validation and checkpointing.

        Tries to resume from the newest file in ``Train.model_save_dir``;
        if that fails, training starts from freshly initialised weights.
        """
        try:
            last_model = utils_torch.find_new_file(self.config['Train.model_save_dir'])
            self.model.load_state_dict(torch.load(last_model, map_location=self.device))
            print('[info] Restoring weights from last trained file ...')
        except Exception as e:
            # Best-effort resume: report why it failed instead of hiding it.
            print('[info] Can not find last trained file !!!')
            print('[info] Reason: {}'.format(e))
            print('[info] Now it starts to train model from scratch ...')

        # Classes 0-9 get weight 1.0, label 10 (background) gets 0.2.
        # The weights must be on the same device as the logits.
        class_weights = torch.tensor(10 * [1.0] + [0.2],
                                     dtype=torch.float32, device=self.device)
        # Loop-invariant, so build the criterion once.
        criterion = Model.CrossEntropyLoss2d(class_weights)

        try:
            for epoch in range(1, 1 + self.config['Train.max_epochs']):
                train_losses, val_losses = [], []

                # The validation pass below switches to eval mode; switch
                # back so BatchNorm keeps training in every epoch.
                self.model.train()
                pbar = tqdm(self.train_dataset)
                for batch in pbar:
                    batch_x, batch_y = batch[0].to(self.device), batch[1].to(self.device)

                    out = self.model(batch_x)
                    loss = criterion(out, batch_y.long())

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    self.scheduler.step()

                    self.global_setp += 1
                    loss_value = loss.item()
                    train_losses.append(loss_value)
                    pbar.set_description("Epoch:%d Step:%d loss:%.3f" % (epoch, self.global_setp, loss_value))

                    # tensorboardX
                    self.summary_writer.add_scalar('learning rate',
                                                   self.optimizer.param_groups[0]['lr'],
                                                   self.global_setp)
                    self.summary_writer.add_scalar('train loss', loss_value, self.global_setp)
                    self.summary_writer.add_images('train input images', batch_x, self.global_setp)
                    # .numpy() only works on CPU tensors, so move them first.
                    gt_masks = batch_y.detach().cpu().numpy()
                    # argmax over the logits equals argmax over their softmax.
                    pred_masks = torch.argmax(out, dim=1).detach().cpu().numpy()
                    self.summary_writer.add_images('train gt images',
                                                   utils_torch.tran_masks2images(gt_masks),
                                                   self.global_setp)
                    self.summary_writer.add_images('train pred images',
                                                   utils_torch.tran_masks2images(pred_masks),
                                                   self.global_setp)

                # evaluation mode and no gradient tracking for validation
                self.model.eval()
                with torch.no_grad():
                    for batch_x, batch_y in self.val_dataset:
                        batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
                        out = self.model(batch_x)
                        loss = criterion(out, batch_y.long())
                        val_losses.append(loss.item())

                train_avg_loss, val_avg_loss = np.mean(train_losses), np.mean(val_losses)
                print('epoch:%d, train loss:%.5f, val loss:%.5f ' % (epoch, train_avg_loss, val_avg_loss))

                save_name = os.path.join(self.config['Train.model_save_dir'],
                                         'm2nist-seg_epoch{:d}.pth'.format(epoch))
                torch.save(self.model.state_dict(), save_name)
        finally:
            # Flush and close the event file even if training is interrupted.
            self.summary_writer.close()
def init_args():
    """Parse the command-line options of the training script.

    :return: the parsed namespace; its ``cfg_pth`` attribute holds the path
        of the config file (``--cfg_pth``)
    """
    default_cfg = '/home/liheng/PycharmProjects/m2nist-segmentation/pytorch/config.yaml'
    cli = argparse.ArgumentParser()
    cli.add_argument('--cfg_pth',
                     default=default_cfg,
                     type=str,
                     help='The config file path')
    options = cli.parse_args()
    return options
if __name__ == '__main__':
    args = init_args()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # The message also gains the space that was missing after the path.
    if not os.path.isfile(args.cfg_pth):
        raise FileNotFoundError(args.cfg_pth + ' does not exist !')

    trainer = Train(args.cfg_pth)
    trainer.train()
预测模块
预测模块没啥可说的,加载模型,然后预测、将结果可视化即可。代码此处不贴啦。
训练与预测
训练
代码训练过程可视化如下:
预测
预测结果如下:
可以看到,第一幅图像的结果还是能够接受的,而第二幅图像的分割结果就不够精细,数字5和3的部分像素被归为了其他数字。