csv转tfrecord

本文详细介绍了如何将CSV数据转换为TensorFlow的TFRecord格式,适用于大规模数据集的高效读取和处理。首先,我们将探讨CSV文件的结构,然后讨论TFRecord文件的优势。接着,我们将展示一个使用Python和TensorFlow库将CSV数据转换为TFRecord的实例代码,包括数据预处理步骤。最后,我们将提供读取TFRecord文件的示例,以便在TensorFlow模型中使用。
摘要由CSDN通过智能技术生成
#!/usr/bin/env python
# -*- coding:utf-8 -*-

# generate_tfrecord.py

# -*- coding: utf-8 -*-


"""
Usage:
  # From tensorflow/models/
  # Create train data:
  python generate_tfrecord.py --csv_input=data/tv_vehicle_labels.csv  --output_path=train.record
  # Create test data:
  python generate_tfrecord.py --csv_input=data/test_labels.csv  --output_path=test.record
"""


import io
import os
from collections import namedtuple

import pandas as pd
import tensorflow as tf
from PIL import Image
from object_detection.utils import dataset_util

os.chdir(r'E:\AI\打标')

flags = tf.app.flags
flags.DEFINE_string('csv_input', r'E:\AI\打标\train.csv', 'Path to the CSV input')
flags.DEFINE_string('output_path', r'E:\AI\打标\train.record', 'Path to output TFRecord')
FLAGS = flags.FLAGS


# TO-DO replace this with label map
def class_text_to_int(row_label):
    if row_label == 'person':     # 需改动
        return 1
    else:
        None


def split(df, group):
    data = namedtuple(
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
将.csv文件换成.record文件,需要经过以下步骤: 1. 准备数据 假设我们有一个.csv文件,其中包含了许多图片的标注信息,每一行代表一张图片的信息,包括图片的路径、标注框的坐标、标注框内的物体类别等。例如: ``` /path/to/image1.jpg,xmin1,ymin1,xmax1,ymax1,class1 /path/to/image2.jpg,xmin2,ymin2,xmax2,ymax2,class2 /path/to/image3.jpg,xmin3,ymin3,xmax3,ymax3,class3 ... ``` 2. 生成.tfrecord文件 使用TensorFlow提供的API,我们可以很容易地将.csv文件换成.tfrecord文件。具体方法如下: ```python import tensorflow as tf # 定义函数,将一行csv数据换成一个tf.train.Example对象 def create_example(image_path, xmin, ymin, xmax, ymax, class_name): with tf.io.gfile.GFile(image_path, 'rb') as f: encoded_image = f.read() example = tf.train.Example(features=tf.train.Features(feature={ 'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded_image])), 'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=['jpeg'.encode('utf-8')])), 'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=[xmin])), 'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=[ymin])), 'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=[xmax])), 'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=[ymax])), 'image/object/class/text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[class_name.encode('utf-8')])), 'image/object/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=[class_id])), })) return example # 读取csv文件,生成tf.train.Example对象,写入tfrecord文件 csv_file = '/path/to/csv_file.csv' output_file = '/path/to/output_file.record' with tf.io.gfile.GFile(output_file, 'wb') as f: writer = tf.io.TFRecordWriter(f.name) with tf.io.gfile.GFile(csv_file, 'r') as csvfile: reader = csv.reader(csvfile) for row in reader: image_path, xmin, ymin, xmax, ymax, class_name = row example = create_example(image_path, float(xmin), float(ymin), float(xmax), float(ymax), class_name) writer.write(example.SerializeToString()) writer.close() ``` 3. 验证.tfrecord文件 我们可以使用TensorFlow提供的API,验证生成的.tfrecord文件是否正确。具体方法如下: ```python import tensorflow as tf # 定义函数,从.tfrecord文件中读取tf.train.Example对象 def parse_example(example_proto): features = { 'image/encoded': tf.io.FixedLenFeature([], tf.string), 'image/format': tf.io.FixedLenFeature([], tf.string), 'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32), 'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32), 'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32), 'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32), 'image/object/class/text': tf.io.VarLenFeature(tf.string), 'image/object/class/label': tf.io.VarLenFeature(tf.int64), } parsed_features = tf.io.parse_single_example(example_proto, features) return parsed_features # 读取.tfrecord文件,验证其中包含的数据是否正确 tfrecord_file = '/path/to/output_file.record' dataset = tf.data.TFRecordDataset(tfrecord_file) dataset = dataset.map(parse_example) for example in dataset.take(10): print(example) ``` 如果输出的结果和原始的.csv文件中的数据一致,就说明生成的.tfrecord文件是正确的。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值