制作自己的TFRecord数据集

一直在做CNN图像分类,原来采用的图像读入方式是放到文件夹下,直接将数据加载进内存,然后再分batch输入网络进行训练,但是后来发现太占用内存了,加了新内存条还是不够用。查阅资料,原来使用TFRecords这种结构能够有效地节省内存空间。

下面就来制作自己的数据集吧!找到10类电商图像,调整和mnist类似,单通道,28x28大小。数据路径为D:\data\TFRecord

import cv2  
import numpy as np 
import os
import os.path
import sys

def suoxiao(image):
    img=cv2.imread(image)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    pic = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
    
    tempName=os.path.splitext(image)[0]
    tempName=tempName.split("/")[-1]
    tempName = tempName+"_x3.jpg"
    cv2.imwrite(tempName,pic)

if __name__ == "__main__":
    for tempImg in os.listdir("D:\\data\\TFRecord\\pixie"):
        tempImg = "D:\\data\\TFRecord\\pixie\\"+tempImg
        suoxiao(tempImg)    

TFRecord是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存。会根据你选择输入文件的类,自动给每一类打上同样的标签。

import os 
import tensorflow as tf 
from PIL import Image  
import matplotlib.pyplot as plt 
import numpy as np
 
cwd='D:\\data\\TFRecord\\'
classes={'banshenqun','duanku','gaogenxie','lianyiqun','niuzaiku'
         ,'pixie','piyi','xizhuang','yezi','yurongfu'} 
writer= tf.python_io.TFRecordWriter("mnist.tfrecords") 
 
for index,name in enumerate(classes):
    class_path=cwd+name+'\\'
    for img_name in os.listdir(class_path): 
        img_path=class_path+img_name 
 
        img=Image.open(img_path)
        img_raw=img.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        })) 
        writer.write(example.SerializeToString())  
 
writer.close()

然后会生成一个mnist.tfrecords文件,制作完毕。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值