TensorFlow的学习之路--创建图像训练所需的tfrecords文件

# -*- coding: utf-8 -*-
"""
Created on Mon Mar 26 17:34:28 2018

@author: kxq
"""

import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm
from PIL import Image
import os
import io
import argparse


parser = argparse.ArgumentParser()
parser.add_argument(
    '--train_dir', type=str,
    default="D:/Justin图像资料/胃胶囊/参数确认/train",
    help='A path to a folder with training data.'
)
parser.add_argument(
    '--val_dir', type=str,
    default="D:/Justin图像资料/胃胶囊/参数确认/test",
    help='A path to a folder with validation data.'
)
parser.add_argument(
    '--save_dir', type=str,
    default='D:/Justin图像资料/胃胶囊/参数确认',
    help='A path to a folder where to save results.'
)
args = parser.parse_args()

def main():
    encoder = create_encoder(args.train_dir)
    # now you can get a folder's name from a class index

    np.save(os.path.join(args.save_dir, 'class_encoder.npy'), encoder)
    convert(args.train_dir, encoder, os.path.join(args.save_dir, 'train.tfrecords'))
    convert(args.val_dir, encoder, os.path.join(args.save_dir, 'val.tfrecords'))

    print('\nCreated two tfrecords files:')
    print(os.path.join(args.save_dir, 'train.tfrecords'))
    print(os.path.join(args.save_dir, 'val.tfrecords'))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# here you can also just use `return array.tostring()`
# but it will make tfrecords files a lot larger and
# you will need to change the input pipeline
def to_bytes(array):
    ##将数组转化为图像
    image = Image.fromarray(array)
    ##内存中读写数组
    tmp = io.BytesIO()
    image.save(tmp, format='jpeg')
    return tmp.getvalue()


def convert(folder, encoder, tfrecords_filename):
    
    images_metadata = collect_metadata(folder, encoder)
    ##tf.python_io.TFRecordWriter作用是生成tfrecord二进制文件
    writer = tf.python_io.TFRecordWriter(tfrecords_filename)
    ##iterrows的作用:遍历data返回元素组【index,row】
    ##得到的row有image_metadata的三个元素[class_name,img_path,class_number]
    for _, row in tqdm(images_metadata.iterrows()):
        ##路径合并,原路径+img_path
        file_path = os.path.join(folder, row.img_path)
        # read an image
        image = Image.open(file_path)

        # Image读取的类型是Image,需转换类型
        array = np.asarray(image, dtype='uint8')

        # some images are grayscale
        if array.shape[-1] != 3:
            array = np.stack([array, array, array], axis=2)

        # 获取标记
        target = int(row.class_number)

        feature = {
            'image': _bytes_feature(to_bytes(array)),
            'target': _int64_feature(target),
        }
        ##Example包含一个键值对数据结构(与dict相同), 使用属性features记录, 
        ##因此, 初始化时必须传入这个features参数, 它是一个tf.train.Features对象.
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())
    ##等同于tf.python_io.TFRecordWriter.close()
    writer.close()


def create_encoder(folder):
    ##os.listdir作用:返回文件夹下的名字列表
    classes = os.listdir(folder)
    ## enumerate作用:既遍历索引又遍历元素
    encoder = {n: i for i, n in enumerate(classes)}
    return encoder

def collect_metadata(folder, encoder):
    ##os.walk的作用:遍历目录下下的列表,返回三个元素组分别为:{遍历的路劲名、该路径下的目录列表、该路径下的文件列表}
    ## list作用,创建一个列表,索引从0开始,0取得是整体目录,所以这里从1开始取  
    subdirs = list(os.walk(folder))[1:]
    metadata = []

    for dir_path, _, files in subdirs:
        ##,split的作用:通过指定分隔符对字符串进行切片,上述得到的路劲为D://xxx//xxx。所以这里只去最后一个,即为文件夹名
        dir_name = dir_path.split('\\')[-1]  #
        for file_name in files:
            ##os.path.join作用:合并路劲,下列得到:【文件夹名,文件夹名+文件名】
            image_metadata = [dir_name, os.path.join(dir_name, file_name)]
            ##.append作用:在列表后面添加新元素
            metadata.append(image_metadata)
    ##pd.dataframe作用:分组作用,把上述的【文件夹名,文件夹名+文件名】分成【文件夹名】【文件夹+文件名】
    ##得到的数据类型是dataframe
    M = pd.DataFrame(metadata)
    M.columns = ['class_name', 'img_path']
    ##lambda的作用:定义一个匿名函数,等同于x=encoder[x]
    ##在dataframe数据中,pandas dataframe.apply()的作用是实现对某一行/列进行处理获得一个新行/新列
    ##这里对class_name行进行操作并赋值给class_number
    M['class_number'] = M.class_name.apply(lambda x: encoder[x])
    ##reset_index作用:通过函数 drop=True 删除原行索引
    ##sample作用:可以从指定的序列中,随机的截取指定长度的片断,不作原地修改。
    M = M.sample(frac=1).reset_index(drop=True)
    return M


if __name__ == '__main__':
    main()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值