应用TensorFlow高级api构建全连接神经网络(2)--api解释

tf.layers.dense

tf.layers.dense(
    inputs,
    units,
    activation=None,
    use_bias=True,
    kernel_initializer=None,
    bias_initializer=tf.zeros_initializer(),
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    trainable=True,
    name=None,
    reuse=None
)
  • inputs:必需,即需要进行操作的输入数据。
  • units:必须,即神经元的数量。
  • activation:可选,默认为 None,如果为 None 则是线性激活。
  • use_bias:可选,默认为 True,是否使用偏置。
  • kernel_initializer:可选,默认为 None,即权重的初始化方法,如果为 None,则使用默认的 Xavier 初始化方法。
  • bias_initializer:可选,默认为零值初始化,即偏置的初始化方法。
  • kernel_regularizer:可选,默认为 None,施加在权重上的正则项。
  • bias_regularizer:可选,默认为 None,施加在偏置上的正则项。
  • activity_regularizer:可选,默认为 None,施加在输出上的正则项。
  • kernel_constraint,可选,默认为 None,施加在权重上的约束项。
  • bias_constraint,可选,默认为 None,施加在偏置上的约束项。
  • trainable:可选,默认为 True,布尔类型,如果为 True,则将变量添加到 GraphKeys.TRAINABLE_VARIABLES 中。
  • name:可选,默认为 None,卷积层的名称。
  • reuse:可选,默认为 None,布尔类型,如果为 True,那么如果 name 相同时,会重复利用。

参考:https://blog.csdn.net/xierhacker/article/details/82747919

tf.estimator.Estimator

# Estimator 类,用来训练和验证 TensorFlow 模型。
class Estimator(object):
    def __init__(self, model_fn, model_dir=None, config=None, params=None,
               warm_start_from=None):
  • model_fn: 模型函数。函数的格式如下:
    • 参数:
      • 1、features: 这是 input_fn 返回的第一项(input_fn 是 train, evaluate 和 predict 的参数)。类型应该是单一的 Tensor 或者 dict。
      • 2、labels: 这是 input_fn 返回的第二项。类型应该是单一的 Tensor 或者 dict。如果 mode 为 ModeKeys.PREDICT,则会默认为 labels=None。如果 model_fn 不接受 mode,model_fn 应该仍然可以处理 labels=None。
      • 3、mode: 可选。指定是训练、验证还是测试。参见 ModeKeys。
      • 4、params: 可选,超参数的 dict。 可以从超参数调整中配置 Estimators。
      • 5、config: 可选,配置。如果没有传则为默认值。可以根据 num_ps_replicas 或 model_dir 等配置更新 model_fn。
    • 返回:
      • EstimatorSpec
  • model_dir: 保存模型参数、图等的地址,也可以用来将路径中的检查点加载至 estimator 中来继续训练之前保存的模型。如果是 PathLike, 那么路径就固定为它了。如果是 None,那么 config 中的 model_dir 会被使用(如果设置了的话),如果两个都设置了,那么必须相同;如果两个都是 None,则会使用临时目录。
  • config: 配置类。
  • params: 超参数的dict,会被传递到 model_fn。keys 是参数的名称,values 是基本 python 类型
  • warm_start_from: 可选,字符串,检查点的文件路径,用来指示从哪里开始热启动。或者是 tf.estimator.WarmStartSettings 类来全部配置热启动。如果是字符串路径,则所有的变量都是热启动,并且需要 Tensor 和词汇的名字都没有变。

参考:https://blog.csdn.net/HappyRocking/article/details/80500172

tf.estimator.Estimator.evaluate

def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None):

使用验证集 input_fn 对 model 进行验证。

  • input_fn:此函数构造出验证所需的输入数据。

  • steps:验证模型的步数。如果是 None,则一直验证下去,直至input_fn 抛出了出界异常。

  • hooks:SessionRunHook子类实例的 list。作为验证的回调函数。

  • checkpoint_path:特定检查点的路径。如果是 None,则默认为 model_dir 中最近的检查点(model_dir 是 tf.estimator.Estimator 的构造函数的参数之一)

  • name:验证的名字。使用者可以针对不同的数据集运行多个验证操作,比如训练集 vs 测试集。不同验证的结果被保存在不同的文件夹中,且分别出现在 tensorboard 中。

返回一个字典,包括 model_fn 中指定的评价指标、global_step

tf.estimator.Estimator.train

def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
  • input_fn:输入函数返回一个元组:features - Dictionary 的字符串特征名到 Tensor 或 SparseTensor。labels - Tensor 或带标签的张量字典。
  • hooks:SessionRunHook 子类实例列表。用于训练循环内的回调。
  • steps:用于训练模型的步骤数。如果为 None,永远训练或训练直到 input_fn 生成 OutOfRange 或 StopIteration 错误。“steps”是逐步进行的。如果你调用两次 train(steps = 10),那么 train 总共有20步。如果 OutOfRange 或 StopIteration 在中间出现差错,train将在前20步之前停止。如果你不想要增量表现,请设置 max_steps 代替。如果设置,max_steps 必须为 None。
  • max_steps:用于 train 模型的总步骤数。如果为 None,永远训练或训练,直到 input_fn 生成 OutOfRange 或 StopIteration 错误。如果设置,steps 必须为None。如果 OutOfRange 或者 StopIteration 在中间出现差错,训练之前应停止 max_steps 步骤。两次调用 train (steps=100) 意味着 200次 train 迭代。另一方面,两个调用 train(max_steps=100)意味着第二次调用将不会执行任何迭代,因为第一次调用完成了所有的100个步骤。
  • saving_listeners: CheckpointSaverListener对象的列表,用于在检查点保存之前或之后立即运行的回调。

https://blog.csdn.net/weixin_42499236/article/details/84189310

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值