MXNet: Data Loading and Augmentation


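Copyright notice: this is an original article by the blogger, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original article and this notice.
Article link: https://blog.csdn.net/weixin_39451323/article/details/93659168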

1. Reading the original image files directly

1.1 Generating the .lst files

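The im2rec.py script generates the .lst files.
Directory layout: data/train and data/val each contain one subfolder per class (cock and ostrich here) with the images inside, and im2rec.py is located in ./tools.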

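Run the following commands to generate the .lst files.

# The first data/train is the prefix; running the command produces train.lst
# The second data/train is the root directory of the images; set it according to where your data is stored
# --list generates the .lst file (run without --list, the same script creates RecordIO files, see section 2)
# --recursive searches the given directory recursively; it is required when train/ contains one folder per class with the images inside
# Note: the first attempt below fails because /tools/im2rec.py is an absolute path; ./tools/im2rec.py works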
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python /tools/im2rec.py data/train data/train --list --recursive
python: can't open file '/tools/im2rec.py': [Errno 2] No such file or directory
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py data/train data/train --list --recursive
cock 0
ostrich 1
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py data/val data/val --list --recursive
cock 0
ostrich 1

Result: data/train.lst and data/val.lst are created in the data directory.
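The .lst file is plain text with one image per line: an index, the label and the image path relative to the root directory, separated by tabs. The following pure-Python sketch approximates what im2rec.py --list --recursive writes, assuming the default extensions .jpeg, .jpg and .png: class folders are visited in sorted order and numbered from 0 (which is why cock gets 0 and ostrich gets 1 above), and the list is shuffled, as im2rec.py does unless --no-shuffle is given. The function names and the fixed seed are illustrative assumptions, not im2rec.py's own code.

```python
import os
import random

# Assumed extension list (im2rec.py's default, as far as I know)
IMAGE_EXTS = ('.jpeg', '.jpg', '.png')


def list_images(root):
    """Yield (index, relative_path, label) for every image below root.

    Each directory that contains images gets its own label, numbered in the
    order the sorted walk reaches it.
    """
    labels = {}
    index = 0
    for path, dirs, files in os.walk(root):
        dirs.sort()   # in-place sort makes os.walk visit class folders alphabetically
        files.sort()
        for fname in files:
            if os.path.splitext(fname)[1].lower() not in IMAGE_EXTS:
                continue  # skip non-image files
            if path not in labels:
                labels[path] = len(labels)
            rel = os.path.relpath(os.path.join(path, fname), root)
            yield index, rel, labels[path]
            index += 1


def make_lst_lines(root, shuffle=True, seed=0):
    """Return .lst lines: index<TAB>label (formatted as a float)<TAB>relative path."""
    items = list(list_images(root))
    if shuffle:
        random.Random(seed).shuffle(items)  # hypothetical fixed seed, for reproducibility
    return ['%d\t%f\t%s' % (idx, label, rel) for idx, rel, label in items]
```

For example, '\n'.join(make_lst_lines('data/train')) should list the same entries as data/train.lst if the assumptions above hold, although the shuffled order will differ.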

1.2 Basic data loading: read_lst.py

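import mxnet as mx
import matplotlib.pyplot as plt

# Training iterator: reads image paths and labels from train.lst, loads the images
# from path_root and yields shuffled batches of 32 images of shape (3, 224, 224)
train_data = mx.image.ImageIter(batch_size=32,
                                data_shape=(3, 224, 224),
                                path_imglist='data/train.lst',
                                path_root='data/train',
                                shuffle=True)
# Validation iterator: same format, no shuffling
val_data = mx.image.ImageIter(batch_size=32,
                              data_shape=(3, 224, 224),
                              path_imglist='data/val.lst',
                              path_root='data/val')
train_data.reset()
print(train_data)
data_batch = train_data.next()   # one DataBatch
print(data_batch)
data = data_batch.data[0]        # NDArray of shape (32, 3, 224, 224)
#print(data)

# Plot the first 4 images: cast to uint8 and convert CHW -> HWC for matplotlib
plt.figure()
for i in range(4):
    save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
    plt.subplot(1, 4, i + 1)
    plt.imshow(save_image)
plt.savefig('image_sample.jpg')

# The same training iterator with augmentation: the shorter edge is resized to 256,
# the image is then cropped to 224x224 (a center crop, since rand_crop is not set),
# and images are randomly mirrored left-right
train_data = mx.image.ImageIter(batch_size=32,
                                data_shape=(3, 224, 224),
                                path_imglist='data/train.lst',
                                path_root='data/train',
                                shuffle=True,
                                resize=256,
                                rand_mirror=True)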


Output:

<mxnet.image.image.ImageIter object at 0x7fc7ac5839e8>
DataBatch: data shapes: [(32, 3, 224, 224)] label shapes: [(32,)]

 
 
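ImageIter ties path_imglist and path_root together: each .lst line supplies an index, one or more label columns and a path that is resolved relative to path_root. The sketch below parses one line that way; parse_lst_line is a hypothetical helper reflecting my understanding of how ImageIter reads the list, not an MXNet function, and the example line is made up.

```python
import os


def parse_lst_line(line, path_root):
    """Split a tab-separated .lst line into (index, labels, full image path)."""
    fields = line.strip().split('\t')
    index = int(fields[0])
    # every field between the index and the path is a label column
    labels = [float(x) for x in fields[1:-1]]
    return index, labels, os.path.join(path_root, fields[-1])


# a made-up line in the format of data/train.lst
print(parse_lst_line('3\t1.000000\tostrich/example.jpg\n', 'data/train'))
# -> (3, [1.0], 'data/train/ostrich/example.jpg') on Linux
```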

2. Reading data from RecordIO files

2.1 Generating the RecordIO files

# --num-thread: number of threads
# data/train.lst: path to the .lst file
# data/train: directory containing the original images
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py --num-thread 16 data/train.lst data/train
Creating .rec file from /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data/train.lst in /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data
time: 0.009440183639526367  count: 0
time: 0.8572912216186523  count: 1000
time: 0.8444476127624512  count: 2000
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py --num-thread 16 data/val.lst data/val
Creating .rec file from /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data/val.lst in /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data
time: 0.014867544174194336  count: 0

 
 

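Result: data/train.rec with its index file data/train.idx, and likewise data/val.rec and data/val.idx, are created in the data directory.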
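The .rec file packs all images into a single binary file, and the .idx file records where each record starts, so a reader can jump straight to any image (useful for shuffling) instead of scanning the file from the beginning. The sketch below shows that idea with a deliberately simplified toy format: a 4-byte length header per record and a text index with one key<TAB>offset line per record. It is not MXNet's actual RecordIO layout, which also stores a magic number, a length/flag word and alignment padding, and im2rec.py additionally packs a header with the label in front of each encoded image; for real files, mx.recordio.MXIndexedRecordIO is the API to use. All function names here are hypothetical.

```python
import io
import struct


def write_records(payloads):
    """Pack byte strings into one buffer; return (data, {key: byte offset})."""
    buf = io.BytesIO()
    index = {}
    for key, payload in enumerate(payloads):
        index[key] = buf.tell()                     # where this record starts
        buf.write(struct.pack('<I', len(payload)))  # toy 4-byte length header
        buf.write(payload)
    return buf.getvalue(), index


def index_lines(index):
    """Text form of the index, one 'key<TAB>offset' line per record."""
    return ['%d\t%d' % (key, offset) for key, offset in sorted(index.items())]


def read_record(data, offset):
    """Jump straight to one record via its offset instead of scanning from the start."""
    (length,) = struct.unpack_from('<I', data, offset)
    start = offset + 4
    return data[start:start + length]


data, index = write_records([b'img0', b'image-one', b'x'])
print(index_lines(index))            # ['0\t0', '1\t8', '2\t21']
print(read_record(data, index[2]))   # b'x'
```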

2.2 Reading RecordIO data: read_rec.py

The mx.io.ImageRecordIter() interface reads RecordIO data files. It accepts several dozen parameters; the code below is a brief example:

import mxnet as mx
import matplotlib.pyplot as plt

# Training iterator over the RecordIO file; resize and random mirroring are added
train_data = mx.io.ImageRecordIter(batch_size=32,
                                   data_shape=(3, 224, 224),
                                   path_imgrec='data/train.rec',
                                   path_imgidx='data/train.idx',
                                   shuffle=True,
                                   resize=256,
                                   rand_mirror=True)

# Validation iterator: resize only, no shuffling or mirroring
val_data = mx.io.ImageRecordIter(batch_size=32,
                                 data_shape=(3, 224, 224),
                                 path_imgrec='data/val.rec',
                                 path_imgidx='data/val.idx',
                                 resize=256)

train_data.reset()
print(train_data)
data_batch = train_data.next()
print(data_batch)
data = data_batch.data[0]
# Plot the first 4 images: cast to uint8 and convert CHW -> HWC for matplotlib
plt.figure()
for i in range(4):
    save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
    plt.subplot(1, 4, i + 1)
    plt.imshow(save_image)
plt.savefig('image_sample_rec.jpg')


Output:

<mxnet.io.io.MXDataIter object at 0x7fc7ac62fa58>
DataBatch: data shapes: [(32, 3, 224, 224)] label shapes: [(32,)]

 
 

3. Data augmentation

3.1 Resizing (resize)

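  • mx.image.ResizeAug(size, interp=2): resizes the shorter edge to size and keeps the aspect ratio
  • mx.image.ForceResizeAug(size=(224,224)): resizes to exactly size, given as (width, height), ignoring the aspect ratio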
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    prefix = 'data-augmentation/resize/'
    image = 'ILSVRC2012_val_00000002.jpg'
    image_name = image.split(".")[0]  # -> 'ILSVRC2012_val_00000002'
    # read the encoded image file in binary mode
    with open('data-augmentation/resize/{}'.format(image), 'rb') as f:
        image_string = f.read()
    # decode into an HWC uint8 NDArray; flag=1 loads a 3-channel color image
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(prefix + image_name))

    # resize the shorter edge to 224 and shrink the other edge by the same factor
    # interp selects the interpolation method; 2 is the default and can be omitted
    shorterResize = mx.image.ResizeAug(size=224, interp=2)
    shorterResize_data = shorterResize(data)
    print("Shape of data:{}".format(shorterResize_data.shape))
    plt.imshow(shorterResize_data.asnumpy())
    plt.savefig('{}_shorterResize.png'.format(prefix + image_name))

    # resize the shorter edge up to 1000; the other edge is enlarged proportionally
    shorterResize = mx.image.ResizeAug(size=1000)
    shorterResize_data = shorterResize(data)
    print("Shape of data:{}".format(shorterResize_data.shape))
    plt.imshow(shorterResize_data.asnumpy())
    plt.savefig('{}_shorterResize_bigsize.png'.format(prefix + image_name))

    # force the output to exactly 224x224 by interpolation; the aspect ratio is
    # not kept, so the person in the image looks distorted
    forceResize = mx.image.ForceResizeAug(size=(224, 224))
    forceResize_data = forceResize(data)
    print("Shape of data:{}".format(forceResize_data.shape))
    plt.imshow(forceResize_data.asnumpy())
    plt.savefig('{}_forceResize.png'.format(prefix + image_name))

    # force resize to size=(200, 300): size is (width, height),
    # so the resulting array has shape (300, 200, 3)
    forceResize = mx.image.ForceResizeAug(size=(200, 300))
    forceResize_data = forceResize(data)
    print("Shape of data:{}".format(forceResize_data.shape))
    plt.imshow(forceResize_data.asnumpy())
    plt.savefig('{}_forceResize_diff.png'.format(prefix + image_name))

Output:

Shape of data:(1440, 1080, 3)
Shape of data:(298, 224, 3)
Shape of data:(1333, 1000, 3)
Shape of data:(224, 224, 3)
Shape of data:(300, 200, 3)

 
 
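Saved figures in data-augmentation/resize/: ILSVRC2012_val_00000002 with the suffixes _original, _shorterResize, _shorterResize_bigsize, _forceResize and _forceResize_diff (all .png).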
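The printed shapes follow from simple integer arithmetic. This sketch reproduces them under two assumptions that agree with the log above: ResizeAug scales the shorter edge to size and computes the longer edge with integer division, and ForceResizeAug reads size as (width, height) while arrays are (height, width). The helper names are hypothetical, not MXNet functions.

```python
def resize_short_shape(h, w, size):
    """ResizeAug-style: the shorter edge becomes size, the longer edge is scaled (floored)."""
    if h > w:
        return size * h // w, size
    return size, size * w // h


def force_resize_shape(size):
    """ForceResizeAug-style: size is (width, height), arrays are (height, width)."""
    width, height = size
    return height, width


print(resize_short_shape(1440, 1080, 224))   # (298, 224)
print(resize_short_shape(1440, 1080, 1000))  # (1333, 1000)
print(force_resize_shape((200, 300)))        # (300, 200)
```

Under the same formula, resize=256 as used by the iterators in sections 1.2 and 2.2 would turn this image into (341, 256) before the 224x224 crop; MXNet's C++ record iterator may round slightly differently.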

3.2 Cropping (crop)

  • center crop: mx.image.CenterCropAug
  • random crop: mx.image.RandomCropAug
  • random sized crop: mx.image.RandomSizedCropAug
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    prefix = 'data-augmentation/crop/'
    image = 'ILSVRC2012_val_00000009.jpg'
    image_name = image.split(".")[0]
    with open('data-augmentation/crop/{}'.format(image), 'rb') as f:
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(prefix + image_name))

    # center crop: take the window around the center point of the image
    centerCrop = mx.image.CenterCropAug(size=(224, 224))
    class_centerCrop_data = centerCrop(data)
    print("Shape of data:{}".format(class_centerCrop_data.shape))
    plt.imshow(class_centerCrop_data.asnumpy())
    plt.savefig('{}_centerCrop.png'.format(prefix + image_name))

    # random crop: the window's top-left corner is chosen at random, so the
    # crop can land anywhere as long as it stays inside the image
    randomCrop = mx.image.RandomCropAug(size=(224, 224))
    class_randomCrop_data = randomCrop(data)
    print("Shape of data:{}".format(class_randomCrop_data.shape))
    plt.imshow(class_randomCrop_data.asnumpy())
    plt.savefig('{}_randomCrop.png'.format(prefix + image_name))

    # size: output size; area: the crop covers a random fraction of the original
    # image area, at least area (here 8%); ratio: range of width/height ratios,
    # once the area is drawn, width and height follow from a ratio in this range;
    # the cropped window is finally resized to size
    randomSizeCrop = mx.image.RandomSizedCropAug(size=(224, 224), area=0.08,
                                                 ratio=(3.0 / 4.0, 4.0 / 3.0))
    class_randomSizedCrop_data = randomSizeCrop(data)
    print("Shape of data:{}".format(class_randomSizedCrop_data.shape))
    plt.imshow(class_randomSizedCrop_data.asnumpy())
    plt.savefig('{}_randomSizedCrop.png'.format(prefix + image_name))


Output:

Shape of data:(1440, 1080, 3)
Shape of data:(224, 224, 3)
Shape of data:(224, 224, 3)
Shape of data:(224, 224, 3)

 
 
Saved figures in data-augmentation/crop/: ILSVRC2012_val_00000009 with the suffixes _original, _centerCrop, _randomCrop and _randomSizedCrop (all .png).
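The three croppers differ only in how they pick the crop window; the window is then resized to size, which is why all three outputs above are (224, 224, 3). This pure-Python sketch computes the windows as (x0, y0, width, height), following my reading of MXNet 1.x's crop helpers: center crop centers the window, random crop draws the top-left corner uniformly among positions that keep the window inside the image, and random sized crop draws an area fraction from [area, 1] and an aspect ratio from ratio (log-uniformly in recent versions), retrying up to 10 times. The centered-square fallback and all function names are my own simplifications.

```python
import math
import random


def center_crop_box(h, w, crop_w, crop_h):
    """Window (x0, y0, width, height) centered in an h x w image."""
    return (w - crop_w) // 2, (h - crop_h) // 2, crop_w, crop_h


def random_crop_box(h, w, crop_w, crop_h, rng=random):
    """Top-left corner drawn uniformly so the window stays inside the image."""
    x0 = rng.randint(0, w - crop_w)  # randint includes both ends
    y0 = rng.randint(0, h - crop_h)
    return x0, y0, crop_w, crop_h


def random_sized_crop_box(h, w, area=0.08, ratio=(3.0 / 4.0, 4.0 / 3.0),
                          attempts=10, rng=random):
    """Window covering a random fraction in [area, 1] of the image, width/height within ratio."""
    for _ in range(attempts):
        target_area = rng.uniform(area, 1.0) * h * w
        # log-uniform aspect ratio, so r and 1/r are equally likely
        aspect = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
        new_w = int(round(math.sqrt(target_area * aspect)))
        new_h = int(round(math.sqrt(target_area / aspect)))
        if new_w <= w and new_h <= h:
            return rng.randint(0, w - new_w), rng.randint(0, h - new_h), new_w, new_h
    # simplified fallback: a centered square as large as the shorter edge
    side = min(h, w)
    return center_crop_box(h, w, side, side)


print(center_crop_box(1440, 1080, 224, 224))  # (428, 608, 224, 224)
```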

3.3 Mirroring (mirror)

  • mx.image.HorizontalFlipAug(p=0.5): flips the image left-right with probability p
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000014.JPEG'
    image_name = image.split(".")[0]
    with open('…/image/{}'.format(image), 'rb') as f:
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # p is the probability of flipping the image; with p=0.5 the saved image is
    # left unflipped about half of the time (use p=1.0 to always see the flip)
    mirror = mx.image.HorizontalFlipAug(p=0.5)
    mirror_data = mirror(data)
    plt.imshow(mirror_data.asnumpy())
    plt.savefig('{}_mirror.png'.format(image_name))

Saved figures: ILSVRC2012_val_00000014_original.png and ILSVRC2012_val_00000014_mirror.png.
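Horizontal mirroring just reverses the pixel order within each row, and the random part is a single coin flip with probability p, which is also why the saved mirror image can come out identical to the original. A pure-Python sketch on an image stored as a list of rows of RGB tuples (horizontal_flip is a hypothetical helper, not MXNet code):

```python
import random


def horizontal_flip(image, p=0.5, rng=random):
    """Mirror an image (a list of rows of pixels) left-right with probability p."""
    if rng.random() < p:
        return [row[::-1] for row in image]  # reverse the pixel order in every row
    return image


img = [[(1, 0, 0), (2, 0, 0), (3, 0, 0)]]  # one row of three RGB pixels
print(horizontal_flip(img, p=1.0))  # [[(3, 0, 0), (2, 0, 0), (1, 0, 0)]]
```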

3.4 Brightness (brightness)

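  • mx.image.BrightnessJitterAug(brightness=0.3)
    Do not set brightness too high, or the image will be distorted.
    Output image: every pixel value of the input is multiplied by a random factor drawn from [1-brightness, 1+brightness].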
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    with open('…/image/{}'.format(image), 'rb') as f:
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # the jitter multiplies pixels by a float factor, so convert uint8 -> float32 first
    cast = mx.image.CastAug()
    data = cast(data)
    brightness = mx.image.BrightnessJitterAug(brightness=0.3)
    brightness_data = brightness(data)
    # clip to [0, 255] before casting back to uint8; values above 255 may otherwise wrap around
    brightness_data = mx.nd.clip(brightness_data, a_min=0, a_max=255)
    brightness_data = mx.nd.Cast(brightness_data, dtype='uint8')
    plt.imshow(brightness_data.asnumpy())
    plt.savefig('{}_brightness.png'.format(image_name))

Output: ILSVRC2012_val_00000008_original.png and ILSVRC2012_val_00000008_brightness.png.
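Per pixel, brightness jitter is one multiplication: a factor alpha is drawn uniformly from [1-brightness, 1+brightness] and every channel value is multiplied by it. This pure-Python sketch spells that out on a flat list of pixel values and clips the result to [0, 255] so it fits in uint8; the helper names are hypothetical, not MXNet functions.

```python
import random


def sample_brightness_factor(brightness, rng=random):
    """Draw alpha uniformly from [1 - brightness, 1 + brightness]."""
    return 1.0 + rng.uniform(-brightness, brightness)


def apply_brightness(pixels, alpha):
    """Multiply every value by alpha and clip to the uint8 range [0, 255]."""
    return [min(255.0, max(0.0, v * alpha)) for v in pixels]


print(apply_brightness([100, 150, 200], 1.5))  # [150.0, 225.0, 255.0]
print(apply_brightness([100, 150, 200], 0.5))  # [50.0, 75.0, 100.0]
```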

3.5 Contrast (contrast)

  • mx.image.ContrastJitterAug(contrast=0.3)
    Do not set contrast too high, or the image will be distorted.
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    with open('…/image/{}'.format(image), 'rb') as f:
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # the jitter works with float factors, so convert uint8 -> float32 first
    cast = mx.image.CastAug()
    data = cast(data)
    contrast = mx.image.ContrastJitterAug(contrast=0.3)
    contrast_data = contrast(data)
    # clip to [0, 255] before casting back to uint8; out-of-range values may otherwise wrap around
    contrast_data = mx.nd.clip(contrast_data, a_min=0, a_max=255)
    contrast_data = mx.nd.Cast(contrast_data, dtype='uint8')
    plt.imshow(contrast_data.asnumpy())
    plt.savefig('{}_contrast.png'.format(image_name))

Output: ILSVRC2012_val_00000008_original.png and ILSVRC2012_val_00000008_contrast.png.
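As I read MXNet 1.x's ContrastJitterAug, contrast jitter blends every pixel with the mean gray level of the whole image: out = alpha * pixel + (1 - alpha) * mean_gray, where alpha is drawn from [1-contrast, 1+contrast] and gray is computed with the weights 0.299, 0.587 and 0.114. With alpha > 1 values move away from the mean (more contrast); with alpha < 1 they move toward it. A pure-Python sketch on a tiny image stored as rows of (R, G, B) tuples, with hypothetical helper names:

```python
import random

LUMA = (0.299, 0.587, 0.114)  # RGB -> gray weights


def clip255(v):
    return min(255.0, max(0.0, v))


def sample_alpha(contrast, rng=random):
    """Draw alpha uniformly from [1 - contrast, 1 + contrast]."""
    return 1.0 + rng.uniform(-contrast, contrast)


def mean_gray(image):
    """Average gray level over all pixels of the image."""
    total, count = 0.0, 0
    for row in image:
        for r, g, b in row:
            total += LUMA[0] * r + LUMA[1] * g + LUMA[2] * b
            count += 1
    return total / count


def apply_contrast(image, alpha):
    """Blend every channel of every pixel with the image-wide mean gray level."""
    m = mean_gray(image)
    return [[tuple(clip255(alpha * c + (1.0 - alpha) * m) for c in px) for px in row]
            for row in image]


# the black and white-ish pixels move toward the mean gray (about 100):
# their channels become roughly 50 and 150
print(apply_contrast([[(0, 0, 0), (200, 200, 200)]], 0.5))
```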

3.6 Saturation (saturation)

  • saturation = mx.image.SaturationJitterAug(saturation=0.3)
    Saturation is the purity of a color: the higher the purity, the more vivid the color looks; the lower, the duller it looks.

import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    with open('…/image/{}'.format(image), 'rb') as f:
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # the jitter works with float factors, so convert uint8 -> float32 first
    cast = mx.image.CastAug()
    data = cast(data)
    saturation = mx.image.SaturationJitterAug(saturation=0.3)
    saturation_data = saturation(data)
    # clip to [0, 255] before casting back to uint8; out-of-range values may otherwise wrap around
    saturation_data = mx.nd.clip(saturation_data, a_min=0, a_max=255)
    saturation_data = mx.nd.Cast(saturation_data, dtype='uint8')
    plt.imshow(saturation_data.asnumpy())
    plt.savefig('{}_saturation.png'.format(image_name))

Output: ILSVRC2012_val_00000008_original.png and ILSVRC2012_val_00000008_saturation.png.
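Saturation jitter works like contrast jitter, but per pixel: each pixel is blended with its own gray value, so alpha < 1 moves colors toward gray (duller) and alpha > 1 pushes them away from gray (more vivid), matching the description above. This mirrors my reading of SaturationJitterAug in MXNet 1.x, with alpha drawn from [1-saturation, 1+saturation]; the helpers below are hypothetical and operate on rows of (R, G, B) tuples.

```python
import random

LUMA = (0.299, 0.587, 0.114)  # RGB -> gray weights


def clip255(v):
    return min(255.0, max(0.0, v))


def sample_alpha(saturation, rng=random):
    """Draw alpha uniformly from [1 - saturation, 1 + saturation]."""
    return 1.0 + rng.uniform(-saturation, saturation)


def apply_saturation(image, alpha):
    """Blend each pixel with its own gray value: alpha=0 gives grayscale."""
    out = []
    for row in image:
        new_row = []
        for r, g, b in row:
            gray = LUMA[0] * r + LUMA[1] * g + LUMA[2] * b  # this pixel's gray value
            new_row.append(tuple(clip255(alpha * c + (1.0 - alpha) * gray) for c in (r, g, b)))
        out.append(new_row)
    return out


# pure red becomes a dark gray of about 76 in every channel when alpha = 0,
# while an already gray pixel stays (about) unchanged for any alpha
print(apply_saturation([[(255, 0, 0), (100, 100, 100)]], 0.0))
```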

Last published: 2019-06-26 10:07:52