利用小型数据集m2nist进行语义分割——(一)数据集介绍
微信公众号:幼儿园的学霸
目录
缘由
代码地址:https://github.com/leonardohaig/m2nist-segmentation
最近晚上有点失眠,玩手机伤眼睛,那就学习吧。考虑到没有写过分割网络,尤其是没有用pytorch写过分割网络,那就用pytorch在小型数据集上练手一下分割网络吧。
前言
深度学习的典型任务一般认为是分类、检测和分割。初次入手深度学习时,一般都会自己搭建一个小型的神经网络来进行体验,对于分类任务,入门时一般利用的数据集为mnist
和fashion-mnist
。但是对于检测和分割任务,并没有通用的小型数据集,如果采用voc/kitti/bdd100k
等,数据集达到了数十G,明显不适合用来入门学习,经搜索,发现m2nist
数据集比较适合这种任务。
数据集介绍
m2nist
数据集可以在kaggle上搜索到,具体见参考链接1,它是mnist
的升级版。
它包含两个文件,conmbined.npy
和segmented.npy
,都是numpy
格式,可以直使用numpy.load()
读入。
读入后,前者shape=[5000,64,84]
,第1维表示图像的数量,后面为图像的尺寸,图像为单通道灰度图,每张图像最多包含mnist
数据集中的3个数字,后者shape=[5000,64,84,11]
,为数据集的标签,表示分割掩码,最后一个维度是11,表示经过one-hot编码过的mask。对于通道0~9,如果原图包含某个数字k,则第k通道[原图中k出现的位置]=1
,其他位置都是0;对于最后一个通道是背景的mask,意义是如果该位置没有数字出现,就应该是1,否则是0,这是因为在分割中背景被当做单独的类别对待。如某个图像及其label的各通道如下所示:
我采用的是参考链接3中处理后的数据集,和原始数据集相比,1)其将label处理为了单通道图像,使用时,需要自己进行one-hot处理,得到11通道结果;2)包含有每张图像上各数字的矩形框信息,因此可以用来进行做检测任务。
数据下载/读取/显示
可以参考连接3.我的代码如下:
其中将标签转换为one-hot编码部的代码拷贝于keras中代码片段。
#!/usr/bin/env python3
#coding=utf-8
#============================#
#Program:down_data.py
#
#Date:20-4-10
#Author:liheng
#Version:V1.0
#============================#
import numpy as np
import os
import requests
import zipfile
from six.moves import urllib
from tqdm import tqdm
import matplotlib.pyplot as plt
def to_categorical(y, num_classes=None, dtype='float32'):
"""Converts a class vector (integers) to binary class matrix.
E.g. for use with categorical_crossentropy.
# Arguments
y: class vector to be converted into a matrix
(integers from 0 to num_classes).
num_classes: total number of classes.
dtype: The data type expected by the input, as a string
(`float32`, `float64`, `int32`...)
# Returns
A binary matrix representation of the input. The classes axis
is placed last.
# Example
```python
# Consider an array of 5 labels out of a set of 3 classes {0, 1, 2}:
> labels
array([0, 2, 1, 2, 0])
# `to_categorical` converts this into a matrix with as many
# columns as there are classes. The number of rows
# stays the same.
> to_categorical(labels)
array([[ 1., 0., 0.],
[ 0., 0., 1.],
[ 0., 1., 0.],
[ 0., 0., 1.],
[ 1., 0., 0.]], dtype=float32)
```
"""
y = np.array(y, dtype='int')
input_shape = y.shape
if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
input_shape = tuple(input_shape[:-1])
y = y.ravel()
if not num_classes:
num_classes = np.max(y) + 1
n = y.shape[0]
categorical = np.zeros((n, num_classes), dtype=dtype)
categorical[np.arange(n), y] = 1
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)
return categorical
def download_from_url(url, dst):
"""
@param: url to download file
@param: dst place to put the file
"""
file_size = int(urllib.request.urlopen(url).info().get('Content-Length', -1))
if os.path.exists(dst):
first_byte = os.path.getsize(dst)
else:
first_byte = 0
print(file_size)
if first_byte >= file_size:
return file_size
header = {"Range": "bytes=%s-%s" % (first_byte, file_size)}
pbar = tqdm(total=file_size, initial=first_byte, unit='B', unit_scale=True, desc=url.split('/')[-1])
req = requests.get(url, headers=header, stream=True)
with (open(dst, 'ab')) as f:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
pbar.update(1024)
pbar.close()
return file_size
def download_m2nist_if_not_exist():
"""
:return:
"""
data_rootdir = os.path.join(os.path.split(os.path.realpath(__file__))[0],'m2nist')
if not os.path.exists(data_rootdir): # 保存路径不存在,则创建该路径
os.mkdir(data_rootdir)
m2nist_zip_path = os.path.join(data_rootdir, 'm2nist.zip')
if os.path.exists(m2nist_zip_path):
return
os.makedirs(data_rootdir, exist_ok=True)
m2nist_zip_url = 'https://raw.githubusercontent.com/akkaze/datasets/master/m2nist.zip'
download_from_url(m2nist_zip_url, m2nist_zip_path)
zipf = zipfile.ZipFile(m2nist_zip_path)
zipf.extractall(data_rootdir)
zipf.close()
def show_img_mask(img,mask):
mask = to_categorical(mask,11,dtype=np.uint8)
plt.figure(figsize=(4, 4))
plt.subplot(4, 4, 1)
plt.xticks([])
plt.yticks([])
plt.grid('off')
plt.imshow(img, cmap='Greys_r')
plt.xlabel('img')
for idx in range(11):
plt.subplot(4, 4, idx + 5)
plt.xticks([])
plt.yticks([])
plt.grid('off')
mask_vis = mask[:, :, idx]
plt.imshow(mask_vis, cmap='Greys_r')
plt.xlabel(str(idx))
# plt.get_current_fig_manager().full_screen_toggle()
plt.show()
def show_m2nist(imgs_pth,masks_pth):
"""
:param data_rootdir:
:return:
"""
assert os.path.isfile(imgs_pth),imgs_pth + ' path not exist !'
assert os.path.isfile(masks_pth), masks_pth + ' path not exist !'
imgs = np.load(imgs_pth).astype(np.uint8)
masks = np.load(masks_pth).astype(np.uint8)
for i in range(imgs.shape[0]):
# 转换为one-hot编码
mask = to_categorical(masks[i], 11, dtype=np.uint8)
plt.figure(figsize=(4, 4))
plt.subplot(4, 4, 1)
plt.xticks([])
plt.yticks([])
plt.grid('off')
plt.imshow(imgs[i], cmap='Greys_r')
plt.xlabel('img' + str(i))
for idx in range(11):
plt.subplot(4, 4, idx + 5)
plt.xticks([])
plt.yticks([])
plt.grid('off')
mask_vis = mask[:, :, idx]
plt.imshow(mask_vis, cmap='Greys_r')
plt.xlabel(str(idx))
# plt.get_current_fig_manager().full_screen_toggle()
plt.show()
def split_m2nist(data_rootdir):
"""
:param data_rootdir:
:return:
"""
assert os.path.exists(data_rootdir), data_rootdir + ' path not exist !'
imgs = np.load(os.path.join(data_rootdir, 'combined.npy'))
masks = np.load(os.path.join(data_rootdir, 'segmented.npy'))
val_ratio = 0.2
num_data = imgs.shape[0]
num_train = int(num_data * (1 - val_ratio))
train_imgs_pth = os.path.join(data_rootdir, 'train_imgs.npy')
train_masks_pth = os.path.join(data_rootdir,'train_masks.npy')
val_imgs_pth = os.path.join(data_rootdir,'val_imgs.npy')
val_masks_pth = os.path.join(data_rootdir, 'val_masks.npy')
np.save(train_imgs_pth,imgs[:num_train,...])
np.save(train_masks_pth,masks[:num_train,...])
np.save(val_imgs_pth,imgs[num_train:,...])
np.save(val_masks_pth,masks[num_train:,...])
if __name__ == '__main__':
data_rootdir = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'm2nist')
# download_m2nist_if_not_exist()
# split_m2nist(data_rootdir)
imgs_pth = os.path.join(data_rootdir,'train_imgs.npy')
masks_pth = os.path.join(data_rootdir,'train_masks.npy')
show_m2nist(imgs_pth,masks_pth)
print('Hello world !')
main函数中 注释掉的第1行是用来下载数据集,第2行是将数据集分为训练集(80%)和验证集(20%)。后面部分是观察图像及其标签。如下图所示,为读取的训练集中第一张图片及其标签one-hot的结果:
数据集下载完毕后进入正式代码的编写,接下来将进行网络框架的设计。
参考链接
1.Multidigit MNIST(M2NIST)
2.M2NIST Segmentation / U-net
3.一个超小型分割检测数据集
4.代码地址