tensorflow数据集Dataset和估算器Estimator

tensorflow数据集和估算器

  • 数据集:一种创建输入管道(即,将数据读入您的程序)的全新方式
  • 估算器:一种创建tensorflow模型的高级方式。估算器包括适用于常见机器学习任务的预制模型,不过,您也可以使用它们创建自己的自定义模型。
    下面是它们在tensorflow架构内的转配方式。

1.数据集介绍

数据集是一种我tensorflow模型创建输入管道的新方式。使用此API的性能要比使用feed_dict或队列式管道的性能高得多,而且此API更简洁,使用起来更容易。数据集在1.3版本中位于tf.contrib.data中,预计会在1.4版本中将此API移动到核心中。

从高层次而言,数据集由以下类组成:

其中:

  • Dataset:基类,包括用于创建和转换数据集的函数。允许从内存中的数据或从python生成器初始化数据集。
  • TextLineDataset:从文本文件中读取各行内容。
  • TFRecordDataset:从TFRecord文件中读取记录。
  • FixedLengthRecordDataset:从二进制文件中读取固定大小的记录。
  • Iterator:提供了一种一次获取一个数据集元素的方法。

在训练模型时,需要一个可以读取输入文件并返回特征和标签数据的函数。估算器要求您创建一个具有以下格式的函数:

def input_fn():
    ...<code>...
    return ({'SepalLength':[values], ..<etc>..,'PetalWidth':[values]},[IrisFlowerType])
#返回一个表示特征的字典列表和表示标签的列表

返回值必须是一个按照如下方式组织的两元素元组:

  • 第一个元素必须是一个字典(其中的每一个输入特征都是一个键)
  • 第二个元素是一个用于训练批次的标签列表

下面是使用Dataset API实现此函数的方式。我们会将它包装到一个“输入函数(my_input_fn)”中,这个输入函数用于为估算器模型提供数据:

def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):
   def decode_csv(line):
       parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
       label = parsed_line[-1:] # Last element is the label
       del parsed_line[-1] # Delete last element
       features = parsed_line # Everything (but last element) are the features
       d = dict(zip(feature_names, features)), label
       return d

   dataset = (tf.contrib.data.TextLineDataset(file_path) # Read text file
       .skip(1) # Skip header row
       .map(decode_csv)) # Transform each elem by applying decode_csv fn
   if perform_shuffle:
       # Randomizes input using a window of 256 elements (read into memory)
       dataset = dataset.shuffle(buffer_size=256)
   dataset = dataset.repeat(repeat_count) # Repeats dataset this # times
   dataset = dataset.batch(32)  # Batch size to use
   iterator = dataset.make_one_shot_iterator()
   batch_features, batch_labels = iterator.get_next()
   return batch_features, batch_labels
  • TextLineDataset:在您使用Dataset API的文件式数据集时,它将为您执行大量的内存管理工作。例如,你可以读入比内存大得多的数据集文件,或者以参数形式指定列表,读入多个文件。
  • shuffle:读取buffer_size记录,然后打乱(随机化)它们的顺序。
  • map:调用decode_csv函数,并将数据集中的每个元素作为一个参数(由于我们使用的是TextLineDataset,每个元素都将是一行CSV文本)。然后,我们将向每一行应用decode_csv。
  • decode_csv:将每一行拆分成各个字段,根据需要提供默认值。然后,返回一个包含字段键和字段值的字典。map函数将使用字典更新数据集中的每个元素(行)。

2.估算器介绍

估算器是一种高级API,使用这种API,您在训练tensorflow模型时就不在像之前那样需要编写大量的样板文件代码。估算器也非常灵活,如果你对模型有具体的要求,它允许你替换默认行为。
使用估算器,可以通过两种可能的方式构建模型:

  • Pre-made Estimators(预制估算器):这些是预先定义的估算器,旨在生成特定类型的模型。
  • Estimator(基类):允许你使用model_fn函数完全掌控模型的创建方式。
    所有的估算器都使用input_fn,它为估算器提供输入数据。
    下面的代码可以将预测鸢尾花类型的估算器实例化:
# Create the feature_columns, which specifies the input to our model.
# All our input features are numeric, so use numeric_column for each one.
feature_columns = [tf.feature_column.numeric_column(k) for k in feature_names]

# Create a deep neural network regression classifier.
# Use the DNNClassifier pre-made estimator
classifier = tf.estimator.DNNClassifier(
   feature_columns=feature_columns, # The input features to our model
   hidden_units=[10, 10], # Two layers, each with 10 neurons
   n_classes=3,
   model_dir=PATH) # Path to where checkpoints etc are stored

现在有了一个可以开始训练的估算器。
(全文见通过机器学习让医疗数据更好用

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
DataFrame是Pandas库中的一种数据结构,用于处理和分析结构化数据。而TensorFlow是一个开源的机器学习框架,用于构建和训练机器学习模型。DataFrames和TensorFlow数据集可以一起使用,以便在数据处理和模型训练之间实现无缝的集成。 首先,可以使用Pandas库将数据加载到DataFrame中,然后对数据进行清洗、转换和探索性分析。DataFrame提供了各种功能,如数据过滤、排序、合并和统计汇总,可以简化对数据的操作和处理。 然后,可以将DataFrame中的数据转换为TensorFlow数据集,以便在TensorFlow中进行模型训练。TensorFlow提供了一个tf.data模块,用于创建和处理大规模的数据集。可以使用tf.data.Dataset.from_tensor_slices()函数将DataFrame转换为TensorFlow数据集。 在TensorFlow中,可以使用Dataset API提供的方法对数据集进行处理和转换,例如批处理、重复、随机化等。这些方法可以帮助我们准备用于训练的数据集,并确保数据在每个训练轮次中都能以随机的顺序传递给模型。 最后,可以使用TensorFlow构建和训练机器学习模型,通过迭代训练数据集中的样本来调整模型的权重和参数。利用DataFrame和TensorFlow数据集的集成,可以更好地管理和处理数据,提高模型训练的效率和准确性。 总之,DataFrame和TensorFlow数据集的结合可以提供一个完整的数据处理和模型训练的工作流程,使数据科学家和机器学习工程师能够更方便、高效地处理和分析结构化数据,并训练准确可靠的机器学习模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值