"""
@author: khoing
@contact: Khoing@126.com
@time: 2019/12/16 14:36
@file: tf_data_generate_tfrecord.py
"""
import matplotlib as mpl # Matplotlib 是 Python 的绘图库。 它可与 NumPy 一起使用
import matplotlib.pyplot as plt # Python数据可视化matplotlib.pyplot
# %matplotlib inline #在使用jupyter notebook 或者 jupyter qtconsole的时候,经常会用到%matplotlib inline。其作用就是在你调用plot()进行画图或者直接输入Figure的实例对象的时候,会自动的显示并把figure嵌入到console中。
import numpy as np # 数值计算扩展。这种工具可用来存储和处理大型矩阵
import sklearn # 机器学习中常用的第三方模块,对常用的机器学习方法进行了封装,包括回归(Regression)、降维(Dimensionality Reduction)、分类(Classfication)、聚类(Clustering)等方法。
import pandas as pd # 是python的一个数据分析包
import os # 系统编程的操作模块,可以处理文件和目录
import sys # sys模块包含了与Python解释器和它的环境有关的函数
import time
import tensorflow as tf
from tensorflow import keras
##################################################################################################
# Select the GPU: make only CUDA device 0 visible to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
##################################################################################################
# Print the TensorFlow and interpreter versions, then the name and version
# of every library this script relies on.
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)
"""output:
2.0.0
sys.version_info(major=3, minor=7, micro=4, releaselevel='final', serial=0)
matplotlib 3.1.1
numpy 1.16.5
pandas 0.25.3
sklearn 0.21.3
tensorflow 2.0.0
tensorflow_core.keras 2.2.4-tf
"""
##################################################################################################
# Directory scanned for input files; the recorded output below shows it
# holding train_*, valid_* and test_* CSV shards.
source_dir = "./generate_csv/"
# Print the directory's entries exactly as os.listdir returns them.
print(os.listdir(source_dir))
"""output:
['test_00.csv', 'test_01.csv', 'test_02.csv', 'test_03.csv', 'test_04.csv', 'test_05.csv', 'test_06.csv', 'test_07.csv', 'test_08.csv', 'test_09.csv',
'train_00.csv', 'train_01.csv', 'train_02.csv', 'train_03.csv', 'train_04.csv', 'train_05.csv', 'train_06.csv', 'train_07.csv', 'train_08.csv', 'train_09.csv', 'train_10.csv', 'train_11.csv', 'train_12.csv', 'train_13.csv', 'train_14.csv', 'train_15.csv', 'train_16.csv', 'train_17.csv', 'train_18.csv', 'train_19.csv',
'valid_00.csv', 'valid_01.csv', 'valid_02.csv', 'valid_03.csv', 'valid_04.csv', 'valid_05.csv', 'valid_06.csv', 'valid_07.csv', 'valid_08.csv', 'valid_09.csv'
]
"""
def get_filenames_by_prefix(source_dir, prefix_name):
    """Return paths of the entries in ``source_dir`` whose names start with ``prefix_name``.

    Args:
        source_dir: Directory to scan (top level only, via ``os.listdir``).
        prefix_name: Prefix an entry's name must start with to be selected.

    Returns:
        A list of ``os.path.join(source_dir, name)`` paths, ordered by entry
        name so the result does not depend on the directory listing order.

    Raises:
        OSError: Propagated from ``os.listdir``, e.g. if ``source_dir`` does
            not exist.
    """
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    return [
        os.path.join(source_dir, filename)
        for filename in sorted(os.listdir(source_dir))
        if filename.startswith(prefix_name)
    ]
import pprint

# Collect the file paths of each data split, in train / valid / test order.
train_filenames, valid_filenames, test_filenames = (
    get_filenames_by_prefix(source_dir, split_prefix)
    for split_prefix in ("train", "valid", "test")
)

# Show the three path lists, one after another.
pprint.pprint(train_filenames)
pprint.pprint(valid_filenames)
pprint.pprint(test_filenames)
"""output:
['./generate_csv/train_00.csv',
'./generate_csv/train_01.csv',
'./generate_csv/train_02.csv',
'./generate_csv/train_03.csv',
'./generate_csv/train_04.csv',
'./generate_csv/train_05.csv',
'./generate_csv/train_06.csv',
'./generate_csv/train_07.csv',
'./generate_csv/train_08.csv',
'./generate_csv/train_09.csv',
'./generate_csv/train_10.csv',
'./generate_csv/train_11.csv',
'./generate_csv/train_12.csv',
'./generate_csv/train_13.csv',
'./generate_csv/train_14.csv',
'./generate_csv/train_15.csv',
'./generate_csv/train_16.csv',
'./generate_csv/train_17.csv',
'./generate_csv/train_18.csv',
'./generate_csv/train_19.csv']
['./generate_csv/valid_00.csv',
'./generate_csv/valid_01.csv',
'./generate_csv/valid_02.csv',
'./generate_csv/valid_03.csv',
'./generate_csv/valid_04.csv',
'./generate_csv/valid_05.csv',
'./generate_csv/valid_06.csv',
'./generate_csv/valid_07.csv',
'./generate_csv/valid_08.csv',
'./generate_csv/valid_09.csv']
['./generate_csv/test_00.csv',
'./generate_csv/test_01.csv',
'./generate_csv/test_02.csv',
'./generate_csv/test_03.csv',
'./generate_csv/test_04.csv',
'./generate_csv/test_05.csv',
'./generate_csv/test_06.csv',
'./generate_csv/test_07.csv',
'./generate_csv/test_08.csv',
'./generate_csv/test_09.csv']
"""
##################################################################################################
#从csv文件中读取训练集,验证集和测试集
def parse_csv_line(line, n_fields = 9):
defs = [tf.constant(np.nan)] * n_fields
parsed_fields = tf.io
Tensorflow-datasetI使用_4.csv转为dataset再转为tfrecord文件,读取解析tfrecord文件,训练生成model
最新推荐文章于 2022-01-17 19:04:01 发布
![](https://img-home.csdnimg.cn/images/20240711042549.png)