tf.constant_initializer

参考  tf.train.Coordinator - 云+社区 - 腾讯云

目录

一、使用方法

二、类中的函数

1、__init__

2、__call__

3、from_config

4、get_config


一、使用方法

一个类,初始化器,它生成具有常量值的张量。由新张量的期望shape后面的参数value指定。参数value可以是常量值,也可以是类型为dtype的值列表。如果value是一个列表,那么列表的长度必须小于或等于由张量的期望形状所暗示的元素的数量。如果值中的元素总数小于张量形状所需的元素数,则值中的最后一个元素将用于填充剩余的元素。如果值中元素的总数大于张量形状所需元素的总数,初始化器将产生一个ValueError。

参数:

  • value: Python标量、值列表或元组,或n维Numpy数组。初始化变量的所有元素将在value参数中设置为对应的值。
  • dtype: 数据类型。
  • verify_shape: 布尔值,用于验证value的形状。如果为真,如果value的形状与初始化张量的形状不兼容,初始化器将抛出错误。

可能产生的异常:

  • TypeError: If the input value is not one of the expected types.

示例:下面的示例可以使用numpy重写。ndarray代替了值列表,甚至重新构造了值列表,如值列表初始化下面的两行注释所示。

import numpy as np
import tensorflow as tf
value = [0, 1, 2, 3, 4, 5, 6, 7]
# value = np.array(value)
# value = value.reshape([2, 4])
init = tf.constant_initializer(value)

print('fitting shape:')
with tf.Session():
    x = tf.get_variable('x', shape=[2, 4], initializer=init)
    x.initializer.run()
    print(x.eval())

Output:
-------------------
fitting shape:
[[ 0.  1.  2.  3.]
[ 4.  5.  6.  7.]]
-------------------


print('larger shape:')
with tf.Session():
   x = tf.get_variable('x', shape=[3, 4], initializer=init)
   x.initializer.run()
   print(x.eval())

Output:
-------------------
larger shape:
[[ 0.  1.  2.  3.]
[ 4.  5.  6.  7.]
[ 7.  7.  7.  7.]]
-------------------


print('smaller shape:')
with tf.Session():
    x = tf.get_variable('x', shape=[2, 3], initializer=init)


Error:
-----------------------------------------------------------------------------------------
Traceback (most recent call last):
  File "D:/tensorflow_learning/test.py", line 11, in <module>
    x = tf.get_variable('x', shape=[2, 3], initializer=init)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1484, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1234, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 538, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 492, in _true_getter
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 920, in _get_single_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 145, in __call__
    return cls._variable_call(*args, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 141, in _variable_call
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 120, in <lambda>
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2441, in default_variable_creator
    expected_shape=expected_shape, import_scope=import_scope)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 147, in __call__
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1104, in __init__
    constraint=constraint)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1212, in _init_from_args
    initial_value(), name="initial_value", dtype=dtype)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 894, in <lambda>
    shape.as_list(), dtype=dtype, partition_info=partition_info)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\init_ops.py", line 219, in __call__
    self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\constant_op.py", line 207, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 497, in make_tensor_proto
    (shape_size, nparray.size))
ValueError: Too many elements provided. Needed at most 6, but received 8
-----------------------------------------------------------------------------------------


print('shape verification:')
init_verify = tf.constant_initializer(value, verify_shape=True)
with tf.Session():
    x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)

Error:
-----------------------------------------------------------------------------------------
Traceback (most recent call last):
  File "D:/tensorflow_learning/test.py", line 12, in <module>
    x = tf.get_variable('x', shape=[3, 4], initializer=init_verify)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1484, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1234, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 538, in get_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 492, in _true_getter
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 920, in _get_single_variable
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 145, in __call__
    return cls._variable_call(*args, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 141, in _variable_call
    aggregation=aggregation)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 120, in <lambda>
    previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 2441, in default_variable_creator
    expected_shape=expected_shape, import_scope=import_scope)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 147, in __call__
    return super(VariableMetaclass, cls).__call__(*args, **kwargs)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1104, in __init__
    constraint=constraint)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py", line 1212, in _init_from_args
    initial_value(), name="initial_value", dtype=dtype)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 894, in <lambda>
    shape.as_list(), dtype=dtype, partition_info=partition_info)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\init_ops.py", line 219, in __call__
    self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\constant_op.py", line 207, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "D:\anaconda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 492, in make_tensor_proto
    (tuple(shape), nparray.shape))
TypeError: Expected Tensor's shape: (3, 4), got (8,).
-----------------------------------------------------------------------------------------

二、类中的函数

1、__init__

__init__(
    value=0,
    dtype=tf.float32,
    verify_shape=False
)

2、__call__

__call__(
    shape,
    dtype=None,
    partition_info=None,
    verify_shape=None
)

3、from_config

from_config(
    cls,
    config
)

从配置字典实例化初始化器。例子:

initializer = RandomUniform(-1, 1)
config = initializer.get_config()
initializer = RandomUniform.from_config(config)

参数:

  • config: 一个Python字典。它通常是get_config的输出。

返回:

  • 一个初始化后的实例。

4、get_config

get_config()

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Wanderer001

ROIAlign原理

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值