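MXNet: Data Loading and Augmentation
Data loading and augmentation
1. Reading the original image files directly
1.1 Generating the .lst file
The im2rec.py script generates the .lst file.
Directory layout: data/train and data/val each contain one subfolder per class (here cock and ostrich), and each subfolder holds the images of that class.
Run the following commands to generate the .lst files.
# The first data/train is the prefix; running the command produces data/train.lst
# The second data/train is the root directory of the data; set it to wherever the images are stored
# --list generates the .lst file (the same script is also used to create RecordIO files)
# --recursive searches the given directory recursively; it is required when, as here, train/ contains one folder per class with the images inside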
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python /tools/im2rec.py data/train data/train --list --recursive
python: can't open file '/tools/im2rec.py': [Errno 2] No such file or directory
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py data/train data/train --list --recursive
cock 0
ostrich 1
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py data/val data/val --list --recursive
cock 0
ostrich 1
Result: data/train.lst and data/val.lst are created, with one line per image.
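To make the .lst format concrete, here is a minimal pure-Python sketch of what `im2rec.py --list --recursive` does, written from our reading of the script rather than copied from it: each subfolder of the root is treated as a class, classes are numbered in sorted order (which matches cock = 0 and ostrich = 1 above), and every image becomes one tab-separated line of index, label and relative path. The helper name `make_lst` and the extension list are assumptions for illustration; the real script also shuffles the lines by default, which this sketch does not.

```python
import os
import tempfile

# Assumption: file extensions treated as images
IMAGE_EXTS = ('.jpg', '.jpeg', '.png')


def make_lst(root):
    """Hypothetical helper: build .lst lines for a root/<class>/<image> layout.

    Returns the sorted class names and lines "index<TAB>label<TAB>relative_path".
    """
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    lines = []
    index = 0
    for label, cls in enumerate(classes):
        for fname in sorted(os.listdir(os.path.join(root, cls))):
            if fname.lower().endswith(IMAGE_EXTS):
                # The label is written as a float, e.g. "0.000000"
                lines.append('%d\t%f\t%s' % (index, label, os.path.join(cls, fname)))
                index += 1
    return classes, lines


if __name__ == '__main__':
    # Build a tiny fake dataset with two classes and print its .lst lines
    with tempfile.TemporaryDirectory() as root:
        for rel in ['cock/a.jpg', 'cock/b.jpg', 'ostrich/c.jpg', 'ostrich/readme.txt']:
            os.makedirs(os.path.join(root, os.path.dirname(rel)), exist_ok=True)
            open(os.path.join(root, rel), 'wb').close()
        classes, lines = make_lst(root)
        for label, cls in enumerate(classes):
            print(cls, label)
        print('\n'.join(lines))
```

Non-image files such as readme.txt are skipped, and the index keeps counting across classes, so it uniquely identifies each image.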
1.2 Basic data loading: read_lst.py
import mxnet as mx
import matplotlib.pyplot as plt

# Iterators over the images listed in the .lst files
train_data = mx.image.ImageIter(batch_size=32,
                                data_shape=(3, 224, 224),
                                path_imglist='data/train.lst',
                                path_root='data/train',
                                shuffle=True)
val_data = mx.image.ImageIter(batch_size=32,
                              data_shape=(3, 224, 224),
                              path_imglist='data/val.lst',
                              path_root='data/val')
train_data.reset()
print(train_data)
data_batch = train_data.next()
print(data_batch)
data = data_batch.data[0]  # NDArray of shape (32, 3, 224, 224)
#print(data)

# Show the first 4 images of the batch; imshow expects HWC, so transpose from CHW
plt.figure()
for i in range(4):
    save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
    plt.subplot(1, 4, i + 1)
    plt.imshow(save_image)
plt.savefig('image_sample.jpg')

# The same training iterator with resizing and random mirroring added:
# resize=256 scales the shorter side to 256 before the 224x224 crop
train_data = mx.image.ImageIter(batch_size=32,
                                data_shape=(3, 224, 224),
                                path_imglist='data/train.lst',
                                path_root='data/train',
                                shuffle=True,
                                resize=256,
                                rand_mirror=True)
Output:
<mxnet.image.image.ImageIter object at 0x7fc7ac5839e8>
DataBatch: data shapes: [(32, 3, 224, 224)] label shapes: [(32,)]
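The data shape printed above is (32, 3, 224, 224), i.e. NCHW, while matplotlib's `imshow` wants height x width x channels; that is why the script calls `astype('uint8')` and `transpose((1, 2, 0))` on each sample. The NumPy sketch below shows the same conversion on a random stand-in batch (the random data and the helper name `to_displayable` are assumptions for illustration, not part of the original script).

```python
import numpy as np

# Stand-in for data_batch.data[0].asnumpy(): random pixels in the NCHW layout
# produced with batch_size=32 and data_shape=(3, 224, 224) (assumption for the demo)
batch = np.random.randint(0, 256, size=(32, 3, 224, 224)).astype(np.float32)


def to_displayable(chw):
    """Turn one CHW float image into the HWC uint8 array that plt.imshow expects."""
    # Clip first so out-of-range values cannot misbehave when cast to uint8
    return np.clip(chw, 0, 255).astype(np.uint8).transpose((1, 2, 0))


first_four = [to_displayable(batch[i]) for i in range(4)]
print(batch.shape, '->', first_four[0].shape)  # (32, 3, 224, 224) -> (224, 224, 3)
```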
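2. Reading data from RecordIO files
2.1 Generating the RecordIO files
# --num-thread: number of threads to use
# data/train.lst: path to the .lst file
# data/train: directory containing the original images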
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py --num-thread 16 data/train.lst data/train
Creating .rec file from /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data/train.lst in /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data
time: 0.009440183639526367 count: 0
time: 0.8572912216186523 count: 1000
time: 0.8444476127624512 count: 2000
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py --num-thread 16 data/val.lst data/val
Creating .rec file from /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data/val.lst in /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data
time: 0.014867544174194336 count: 0
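Result: data/train.rec and data/train.idx (and likewise data/val.rec and data/val.idx) are created next to the .lst files.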
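For intuition about what the .rec and .idx files hold, below is a simplified sketch of the layout as we understand the dmlc/MXNet RecordIO format: each record starts with a magic number and a length word, its payload is a 24-byte image header (flag, label, id, id2) followed by the encoded image bytes, and records are zero-padded to a multiple of 4 bytes. The .idx file maps each record index to the byte offset where that record starts, which is what lets the iterator jump to records in random order. This is an illustration only (it ignores the continuation flag used to split very large records) and the function names are hypothetical; to create or read real files, use im2rec.py or `mx.recordio.MXIndexedRecordIO` together with `mx.recordio.pack_img`.

```python
import io
import struct
from collections import namedtuple

# Assumption: simplified model of the RecordIO layout, for illustration only
RECORDIO_MAGIC = 0xced7230a
IR_FORMAT = 'IfQQ'  # flag (uint32), label (float32), id (uint64), id2 (uint64)
IRHeader = namedtuple('IRHeader', ['flag', 'label', 'id', 'id2'])


def write_record(frec, payload):
    """Append one record: magic, length word, payload, zero padding to 4 bytes."""
    # The upper 3 bits of the length word would hold a continuation flag; 0 = complete record
    frec.write(struct.pack('<II', RECORDIO_MAGIC, len(payload)))
    frec.write(payload)
    frec.write(b'\x00' * (-len(payload) % 4))


def write_rec_and_idx(frec, fidx, samples):
    """Hypothetical helper. samples: iterable of (index, label, image_bytes)."""
    for index, label, image_bytes in samples:
        # .idx line: record index and the byte offset where its record starts
        fidx.write('%d\t%d\n' % (index, frec.tell()))
        header = IRHeader(0, float(label), index, 0)
        write_record(frec, struct.pack(IR_FORMAT, *header) + image_bytes)


if __name__ == '__main__':
    frec, fidx = io.BytesIO(), io.StringIO()
    # The payloads stand in for JPEG-encoded images
    write_rec_and_idx(frec, fidx, [(0, 0, b'fake-jpeg-0'), (1, 1, b'fake-jpeg-1')])
    print(fidx.getvalue())
```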
2.2 Reading RecordIO data: read_rec.py
The mx.io.ImageRecordIter() interface reads RecordIO data files. It has dozens of parameters; the code below uses a few common ones:
import mxnet as mx
import matplotlib.pyplot as plt

# resize and random mirroring added
train_data = mx.io.ImageRecordIter(batch_size=32,
                                   data_shape=(3, 224, 224),
                                   path_imgrec='data/train.rec',
                                   path_imgidx='data/train.idx',
                                   shuffle=True,
                                   resize=256,
                                   rand_mirror=True)
val_data = mx.io.ImageRecordIter(batch_size=32,
                                 data_shape=(3, 224, 224),
                                 path_imgrec='data/val.rec',
                                 path_imgidx='data/val.idx',
                                 resize=256)
train_data.reset()
print(train_data)
data_batch = train_data.next()
print(data_batch)
data = data_batch.data[0]
plt.figure()
for i in range(4):
    save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
    plt.subplot(1, 4, i + 1)
    plt.imshow(save_image)
plt.savefig('image_sample_rec.jpg')
Output:
<mxnet.io.io.MXDataIter object at 0x7fc7ac62fa58>
DataBatch: data shapes: [(32, 3, 224, 224)] label shapes: [(32,)]
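The iterator hands out fixed-size batches of 32 (see the data shapes above) and, with shuffle=True, visits the records in random order. The pure-Python sketch below illustrates that bookkeeping under our own assumptions: the helper name `epoch_batches` is hypothetical, and filling the last, incomplete batch by wrapping around is only one possible policy. MXNet's iterators report padding through `DataBatch.pad`, and ImageRecordIter's exact behaviour depends on options such as `round_batch`, so check the API documentation before relying on it.

```python
import random


def epoch_batches(num_samples, batch_size, shuffle=True, seed=None):
    """Hypothetical helper: yield (indices, pad) for one pass over the data.

    Assumes batch_size <= num_samples. The last, incomplete batch is topped up
    by wrapping around to the start of this epoch's order; `pad` counts the
    wrapped-around samples so they can be ignored when computing metrics.
    """
    order = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, num_samples, batch_size):
        batch = order[start:start + batch_size]
        pad = batch_size - len(batch)
        yield batch + order[:pad], pad


if __name__ == '__main__':
    # 70 images with batch_size=32: two full batches plus one padded batch
    for indices, pad in epoch_batches(70, 32, seed=0):
        print(len(indices), 'samples, pad =', pad)
```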
3. Data augmentation
3.1 Resizing (resize)
- mx.image.ResizeAug(size=)
- mx.image.ForceResizeAug(size=(224,224))
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    prefix = 'data-augmentation/resize/'
    image = 'ILSVRC2012_val_00000002.jpg'
    image_name = image.split(".")[0]  # -> 'ILSVRC2012_val_00000002'
    # Read the image as raw bytes, then decode it into an HWC uint8 NDArray (flag=1: 3-channel color)
    image_string = open('data-augmentation/resize/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(prefix + image_name))

    # Scale the shorter side to size and the other side by the same ratio (here the image shrinks)
    # interp selects the interpolation method; 2 is the default, so it may be omitted
    shorterResize = mx.image.ResizeAug(size=224, interp=2)
    shorterResize_data = shorterResize(data)
    print("Shape of data:{}".format(shorterResize_data.shape))
    plt.imshow(shorterResize_data.asnumpy())
    plt.savefig('{}_shorterResize.png'.format(prefix + image_name))

    # Same rule with a larger size: the shorter side becomes 1000, so the image is enlarged
    shorterResize = mx.image.ResizeAug(size=1000)
    shorterResize_data = shorterResize(data)
    print("Shape of data:{}".format(shorterResize_data.shape))
    plt.imshow(shorterResize_data.asnumpy())
    plt.savefig('{}_shorterResize_bigsize.png'.format(prefix + image_name))

    # Force the image to 224x224 by interpolation, ignoring the aspect ratio, so the person in it is distorted
    forceResize = mx.image.ForceResizeAug(size=(224, 224))
    forceResize_data = forceResize(data)
    print("Shape of data:{}".format(forceResize_data.shape))
    plt.imshow(forceResize_data.asnumpy())
    plt.savefig('{}_forceResize.png'.format(prefix + image_name))

    # Force-resize with size=(200, 300); size is (width, height), so the output shape is (300, 200, 3)
    forceResize = mx.image.ForceResizeAug(size=(200, 300))
    forceResize_data = forceResize(data)
    print("Shape of data:{}".format(forceResize_data.shape))
    plt.imshow(forceResize_data.asnumpy())
    plt.savefig('{}_forceResize_diff.png'.format(prefix + image_name))
Output:
Shape of data:(1440, 1080, 3)
Shape of data:(298, 224, 3)
Shape of data:(1333, 1000, 3)
Shape of data:(224, 224, 3)
Shape of data:(300, 200, 3)
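The shapes printed above follow a simple rule: ResizeAug scales the shorter side to `size` and the longer side by the same factor (224 * 1440 / 1080 = 298.67, reported as 298, i.e. rounded down), while ForceResizeAug takes `size` as (width, height) and ignores the aspect ratio. The NumPy sketch below reproduces only that shape logic. It uses nearest-neighbour sampling for brevity, whereas MXNet's default interp=2 is area-based, so pixel values would differ; the function names are hypothetical.

```python
import numpy as np


def resize_short_shape(h, w, size):
    """(height, width) after scaling the shorter side to `size`, rounding the other side down."""
    if h > w:
        return size * h // w, size
    return size, size * w // h


def nearest_resize(img, new_h, new_w):
    """Nearest-neighbour resize of an HWC array (not MXNet's default interpolation)."""
    h, w = img.shape[:2]
    rows = np.arange(new_h) * h // new_h
    cols = np.arange(new_w) * w // new_w
    return img[rows][:, cols]


def force_resize(img, size):
    """size is (width, height), as with ForceResizeAug; the aspect ratio is ignored."""
    new_w, new_h = size
    return nearest_resize(img, new_h, new_w)


img = np.zeros((1440, 1080, 3), dtype=np.uint8)  # same shape as the sample photo
print(nearest_resize(img, *resize_short_shape(1440, 1080, 224)).shape)   # (298, 224, 3)
print(nearest_resize(img, *resize_short_shape(1440, 1080, 1000)).shape)  # (1333, 1000, 3)
print(force_resize(img, (200, 300)).shape)                               # (300, 200, 3)
```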
3.2 Cropping (crop)
- center crop
- random crop
- random resize crop
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    prefix = 'data-augmentation/crop/'
    image = 'ILSVRC2012_val_00000009.jpg'
    image_name = image.split(".")[0]
    image_string = open('data-augmentation/crop/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(prefix + image_name))

    # Center crop: take a 224x224 window centred on the image centre
    centerCrop = mx.image.CenterCropAug(size=(224, 224))
    class_centerCrop_data = centerCrop(data)
    print("Shape of data:{}".format(class_centerCrop_data.shape))
    plt.imshow(class_centerCrop_data.asnumpy())
    plt.savefig('{}_centerCrop.png'.format(prefix + image_name))

    # Random crop: a 224x224 window placed at a random position inside the image
    randomCrop = mx.image.RandomCropAug(size=(224, 224))
    class_randomCrop_data = randomCrop(data)
    print("Shape of data:{}".format(class_randomCrop_data.shape))
    plt.imshow(class_randomCrop_data.asnumpy())
    plt.savefig('{}_randomCrop.png'.format(prefix + image_name))

    # size: output image size; area: the crop covers a random fraction of the original area,
    # drawn from [area, 1]; ratio: range of the crop's width/height ratio, which together with
    # the sampled area fixes the crop's width and height. The crop is then resized to size.
    randomSizeCrop = mx.image.RandomSizedCropAug(size=(224, 224), area=0.08,
                                                 ratio=(3/4, 4/3))
    class_randomSizedCrop_data = randomSizeCrop(data)
    print("Shape of data:{}".format(class_randomSizedCrop_data.shape))
    plt.imshow(class_randomSizedCrop_data.asnumpy())
    plt.savefig('{}_randomSizedCrop.png'.format(prefix + image_name))
Output:
Shape of data:(1440, 1080, 3)
Shape of data:(224, 224, 3)
Shape of data:(224, 224, 3)
Shape of data:(224, 224, 3)
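The three crops differ only in how the crop window is chosen. The sketch below computes the windows as (x0, y0, width, height) boxes in plain Python/NumPy, modelled on our understanding of MXNet's behaviour: the centre crop is centred, the random crop stays fully inside the image, and the random-sized crop samples an area fraction uniformly from [area, 1] and a log-uniform aspect ratio, retrying a few times. The fallback to a centred square and all function names are assumptions for illustration; MXNet would additionally resize the random-sized crop to `size`.

```python
import math
import random

import numpy as np


def center_crop_box(h, w, crop_w, crop_h):
    """(x0, y0, width, height) of a crop centred in an h x w image."""
    return (w - crop_w) // 2, (h - crop_h) // 2, crop_w, crop_h


def random_crop_box(h, w, crop_w, crop_h, rng=random):
    """Same-size crop at a uniformly random position that stays inside the image."""
    return rng.randint(0, w - crop_w), rng.randint(0, h - crop_h), crop_w, crop_h


def random_sized_crop_box(h, w, area=0.08, ratio=(3 / 4, 4 / 3), rng=random, attempts=10):
    """Crop whose area fraction is uniform in [area, 1] and whose width/height
    ratio is log-uniform in `ratio`; falls back to a centred square (assumption)."""
    for _ in range(attempts):
        target_area = rng.uniform(area, 1.0) * h * w
        aspect = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
        crop_w = int(round(math.sqrt(target_area * aspect)))
        crop_h = int(round(math.sqrt(target_area / aspect)))
        if crop_w <= w and crop_h <= h:
            return rng.randint(0, w - crop_w), rng.randint(0, h - crop_h), crop_w, crop_h
    side = min(h, w)
    return center_crop_box(h, w, side, side)


def apply_box(img, box):
    """Cut the (x0, y0, width, height) box out of an HWC array."""
    x0, y0, crop_w, crop_h = box
    return img[y0:y0 + crop_h, x0:x0 + crop_w]


if __name__ == '__main__':
    img = np.zeros((1440, 1080, 3), dtype=np.uint8)  # same shape as the sample photo
    print(center_crop_box(1440, 1080, 224, 224))  # (428, 608, 224, 224)
    print(apply_box(img, random_crop_box(1440, 1080, 224, 224)).shape)  # (224, 224, 3)
    # A random-sized crop still has to be resized to 224x224 afterwards
    print(apply_box(img, random_sized_crop_box(1440, 1080)).shape)
```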
3.3 Mirroring (mirror)
- mx.image.HorizontalFlipAug(p=0.5)
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000014.JPEG'
    image_name = image.split(".")[0]
    image_string = open('…/image/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # p is the probability of flipping the image horizontally, so a given run may leave it unflipped
    mirror = mx.image.HorizontalFlipAug(p=0.5)
    mirror_data = mirror(data)
    plt.imshow(mirror_data.asnumpy())
    plt.savefig('{}_mirror.png'.format(image_name))
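A horizontal flip just reverses the width axis of an HWC image, applied with probability p. The NumPy sketch below is our own illustration of that behaviour, not MXNet's code; the function name and the `rng` parameter are hypothetical. With p=0.5, roughly half of the calls leave the image unchanged, which is why a single run of the script above may save an unflipped "mirror" image.

```python
import random

import numpy as np


def horizontal_flip(img, p=0.5, rng=random):
    """Flip an HWC image left-right with probability p; otherwise return it unchanged."""
    if rng.random() < p:
        return img[:, ::-1]  # reverse the width axis
    return img


if __name__ == '__main__':
    img = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # tiny 2x3 RGB stand-in
    rng = random.Random(0)
    flips = sum(horizontal_flip(img, 0.5, rng) is not img for _ in range(1000))
    print('flipped', flips, 'of 1000 calls')
```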
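3.4 Brightness (brightness)
- mx.image.BrightnessJitterAug(brightness=0.3)
Do not set brightness too high, or the image will be distorted.
Output image: the pixel values of the input image multiplied by a random factor drawn from [1-brightness, 1+brightness].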
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    image_string = open('…/image/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # Cast uint8 -> float32 so the jitter arithmetic is not done in uint8
    cast = mx.image.CastAug()
    data = cast(data)
    brightness = mx.image.BrightnessJitterAug(brightness=0.3)
    brightness_data = brightness(data)
    # Clip to [0, 255] before casting back so out-of-range values convert cleanly to uint8
    brightness_data = mx.nd.clip(brightness_data, 0, 255)
    brightness_data = mx.nd.Cast(brightness_data, dtype='uint8')
    plt.imshow(brightness_data.asnumpy())
    plt.savefig('{}_brightness.png'.format(image_name))
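Output: the original and the brightness-jittered image are saved as ILSVRC2012_val_00000008_original.png and ILSVRC2012_val_00000008_brightness.png.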
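The brightness rule described above (one random factor in [1-brightness, 1+brightness] applied to every pixel) fits in a few lines of NumPy. This is a sketch of the technique, not MXNet's implementation; the function name and `rng` parameter are hypothetical, and the clipping step mirrors the fix needed before converting back to uint8.

```python
import random

import numpy as np


def brightness_jitter(img, brightness, rng=random):
    """Multiply all pixels by one factor alpha drawn from [1 - brightness, 1 + brightness]."""
    alpha = 1.0 + rng.uniform(-brightness, brightness)
    out = img.astype(np.float32) * alpha
    # Clip before casting back so bright pixels saturate at 255
    return np.clip(out, 0, 255).astype(np.uint8), alpha


if __name__ == '__main__':
    img = np.full((2, 2, 3), 200, dtype=np.uint8)  # tiny stand-in image
    out, alpha = brightness_jitter(img, 0.3)
    print('alpha = %.3f, pixel 200 -> %d' % (alpha, out[0, 0, 0]))
```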
3.5 Contrast (contrast)
- mx.image.ContrastJitterAug(contrast=0.3)
Do not set contrast too high, or the image will be distorted.
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    image_string = open('…/image/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # Cast uint8 -> float32 so the jitter arithmetic is not done in uint8
    cast = mx.image.CastAug()
    data = cast(data)
    contrast = mx.image.ContrastJitterAug(contrast=0.3)
    contrast_data = contrast(data)
    # Clip to [0, 255] before casting back so out-of-range values convert cleanly to uint8
    contrast_data = mx.nd.clip(contrast_data, 0, 255)
    contrast_data = mx.nd.Cast(contrast_data, dtype='uint8')
    plt.imshow(contrast_data.asnumpy())
    plt.savefig('{}_contrast.png'.format(image_name))
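Output: the original and the contrast-jittered image are saved as ILSVRC2012_val_00000008_original.png and ILSVRC2012_val_00000008_contrast.png.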
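To our understanding, MXNet's ContrastJitterAug draws one factor alpha from [1-contrast, 1+contrast] and blends every pixel with the mean luminance of the whole image: alpha > 1 pushes pixels away from that mean (more contrast), alpha < 1 pulls them towards it. The NumPy sketch below implements that blend under this assumption; the BT.601 luma weights, function name and `rng` parameter are our choices for illustration, and the result is rounded and clipped before converting to uint8.

```python
import random

import numpy as np

# ITU-R BT.601 luma weights for R, G, B
LUMA = np.array([0.299, 0.587, 0.114], dtype=np.float32)


def contrast_jitter(img, contrast, rng=random):
    """Blend the image with its mean gray level: alpha * img + (1 - alpha) * mean_gray."""
    alpha = 1.0 + rng.uniform(-contrast, contrast)
    img = img.astype(np.float32)
    mean_gray = float((img * LUMA).sum(axis=2).mean())
    out = alpha * img + (1.0 - alpha) * mean_gray
    return np.clip(np.rint(out), 0, 255).astype(np.uint8), alpha


if __name__ == '__main__':
    img = np.zeros((2, 2, 3), dtype=np.uint8)
    img[0], img[1] = 50, 150  # two gray rows with mean gray level 100
    out, alpha = contrast_jitter(img, 0.3)
    print('alpha = %.3f, 50 -> %d, 150 -> %d' % (alpha, out[0, 0, 0], out[1, 0, 0]))
```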
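3.6 Saturation (saturation)
- saturation = mx.image.SaturationJitterAug(saturation=0.3)
Saturation is the purity of a color: the higher the saturation, the more vivid the color; the lower, the duller it looks.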
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    image_string = open('…/image/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # Cast uint8 -> float32 so the jitter arithmetic is not done in uint8
    cast = mx.image.CastAug()
    data = cast(data)
    saturation = mx.image.SaturationJitterAug(saturation=0.3)
    saturation_data = saturation(data)
    # Clip to [0, 255] before casting back so out-of-range values convert cleanly to uint8
    saturation_data = mx.nd.clip(saturation_data, 0, 255)
    saturation_data = mx.nd.Cast(saturation_data, dtype='uint8')
    plt.imshow(saturation_data.asnumpy())
    plt.savefig('{}_saturation.png'.format(image_name))
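Output: the original and the saturation-jittered image are saved as ILSVRC2012_val_00000008_original.png and ILSVRC2012_val_00000008_saturation.png.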
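Saturation jitter differs from contrast jitter in what each pixel is blended with: to our understanding, MXNet's SaturationJitterAug mixes every pixel with its own gray (luminance) value, so alpha = 0 gives a grayscale image, alpha = 1 leaves it unchanged and alpha > 1 makes colors more vivid. The NumPy sketch below follows that description; the BT.601 luma weights, function name and `rng` parameter are assumptions for illustration.

```python
import random

import numpy as np

# ITU-R BT.601 luma weights for R, G, B
LUMA = np.array([0.299, 0.587, 0.114], dtype=np.float32)


def saturation_jitter(img, saturation, rng=random):
    """Blend each pixel with its own gray value: alpha * img + (1 - alpha) * gray."""
    alpha = 1.0 + rng.uniform(-saturation, saturation)
    img = img.astype(np.float32)
    gray = (img * LUMA).sum(axis=2, keepdims=True)  # per-pixel luminance, shape (H, W, 1)
    out = alpha * img + (1.0 - alpha) * gray
    return np.clip(np.rint(out), 0, 255).astype(np.uint8), alpha


if __name__ == '__main__':
    img = np.array([[[255, 0, 0], [0, 128, 255], [90, 90, 90]]], dtype=np.uint8)
    out, alpha = saturation_jitter(img, 0.3)
    print('alpha = %.3f' % alpha)
    print(out)
```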
</div>
<link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-e9f16cbbc2.css" rel="stylesheet">
</div>
</article>
<div class="postTime">
<div class="article-bar-bottom">
<span class="time">
Article last published: 2019-06-26 10:07:52 </span>
</div>
</div>