# TensorFlow2.0教程-使用tf.function和AutoGraph提高代码性能

import contextlib

# Build a context-manager helper so expected failures can be wrapped in `with`.
@contextlib.contextmanager
def assert_raises(error_class):
    """Context manager asserting that its body raises ``error_class``.

    An exception of type ``error_class`` is reported on stdout and suppressed.
    Any other ``Exception`` is reported as unexpected and is suppressed as
    well, so the surrounding script keeps running.

    Args:
        error_class: Exception type the wrapped block is expected to raise.

    Raises:
        Exception: If the wrapped block finishes without raising anything.
    """
    try:
        yield
    except error_class as e:
        print('Caught expected exception \n  {}: {}'.format(error_class, e))
    except Exception as e:
        print('Got unexpected exception \n  {}: {}'.format(type(e), e))
    else:
        # Only reached when the body completed normally, i.e. nothing was raised.
        raise Exception('Expected {} to be raised but no error was raised!'.format(
            error_class))



## tf.function

# A tf.function behaves like a regular TensorFlow operation.
# NOTE(review): the `def` line and the call were missing from the extracted
# text; they are reconstructed to match the 2x2 tensor of 2.0 printed below.
@tf.function
def add(a, b):
    """Return ``a + b``."""
    return a+b


add(tf.ones([2, 2]), tf.ones([2, 2]))


<tf.Tensor: id=14, shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
[2., 2.]], dtype=float32)>

# Gradients can be computed through a tf.function.
@tf.function
def add(a, b):
    """Return ``a + b``."""
    return a+b


v = tf.Variable(2.0)
# NOTE(review): the tape lines were missing from the extracted text; they are
# reconstructed so the cell yields the gradient 1.0 printed below.
with tf.GradientTape() as tape:
    res = add(v, 1.0)
tape.gradient(res, v)


<tf.Tensor: id=40, shape=(), dtype=float32, numpy=1.0>

# tf.function calls can be nested inside another tf.function.
@tf.function
def add(a, b):
    """Return ``a + b``."""
    return a + b


@tf.function
def dense_layer(x, w, b):
    """Return ``matmul(x, w) + b`` by calling the nested ``add`` tf.function.

    NOTE(review): the body was missing from the extracted text; it is
    reconstructed to match the 3x2 tensor of 3.0 printed below.
    """
    return add(tf.matmul(x, w), b)


dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))


<tf.Tensor: id=67, shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
[3., 3.],
[3., 3.]], dtype=float32)>


## 跟踪和多态

Python的动态类型意味着可以使用各种参数类型调用函数，Python将在每个场景中执行不同的操作。

# Polymorphism: the same Python function is called with different input types.
@tf.function
def double(a):
    """Return ``a + a``; the Python ``print`` reports the traced argument."""
    print('追踪变量：',a)
    return a + a


# Call with an int32, a float32 and a string tensor in turn.
print('结果:',double(tf.constant(1)))
print()
print('结果:',double(tf.constant(1.1)))
print()
print('结果:',double(tf.constant('c')))
print()

追踪变量： Tensor("a:0", shape=(), dtype=int32)



print('构建许可的追踪')
# Obtain one concrete trace of `double`, restricted to string tensors.
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("执行追踪函数")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("使用不合法参数")
# An int32 tensor does not match the string signature, so this must fail.
with assert_raises(tf.errors.InvalidArgumentError):
    double_strings(tf.constant(1))

构建许可的追踪

tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)

Caught expected exception
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute __inference_double_98 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_98]

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
    """Return the next Collatz value for each element of the int32 vector ``x``.

    Even elements map to ``x // 2`` and odd elements to ``3 * x + 1``.
    """
    print("Tracing with", x)
    return tf.where(tf.equal(x % 2, 0), x // 2, 3 * x + 1)


print(next_collatz(tf.constant([1, 2])))
# The input signature only admits 1-D vectors, so a 2-D input is rejected.
with assert_raises(ValueError):
    next_collatz(tf.constant([[1, 2], [3, 4]]))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception
<class 'ValueError'>: Python inputs incompatible with input_signature: inputs ((<tf.Tensor: id=125, shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4]], dtype=int32)>,)), input_signature ((TensorSpec(shape=(None,), dtype=tf.int32, name=None),))


## 使用Python参数还是Tensor参数？

def train_one_step():
    """Placeholder for a single training step; does nothing."""
    pass


@tf.function
def train(num_steps):
    """Run ``train_one_step`` ``num_steps`` times, reporting the traced argument."""
    print("追踪： num_steps = {}".format(num_steps))
    for _ in tf.range(num_steps):
        train_one_step()


# Called with two different Python integer arguments.
train(num_steps=10)
train(num_steps=20)

追踪： num_steps = 10


# With tensor arguments of the same dtype, the function is not traced again.
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))

追踪： num_steps = Tensor("num_steps:0", shape=(), dtype=int32)

# With tensor arguments, a new trace happens only for a new dtype (int32 was
# already traced in the previous cell, so only the float call traces here).
train(num_steps=tf.constant(10, dtype=tf.int32))
train(num_steps=tf.constant(20.6))

追踪： num_steps = Tensor("num_steps:0", shape=(), dtype=float32)


## 副作用 tf.function

tf.function函数中的print()只会在追踪（tracing）时执行，所以如果要在每次调用时都输出调试信息（副作用），就需要使用tf.print()。

@tf.function
def f(x):
    """Contrast Python ``print`` with ``tf.print`` inside a tf.function."""
    print("追踪：", x)
    tf.print('执行：', x)


# Two calls with the same Python value, then one with a different value.
f(1)
f(1)
f(2)

追踪： 1



external_list = []


def side_effect(x):
    """Print a marker and append ``x`` to the module-level ``external_list``."""
    print('Python side effect')
    external_list.append(x)


@tf.function
def f(x):
    """Route ``x`` to ``side_effect`` through ``tf.py_function``."""
    tf.py_function(side_effect, inp=[x], Tout=[])


f(1)
f(1)
f(1)
# One entry per call is expected (see the printed list below).
print(external_list)

WARNING: Logging before flag parsing goes to stderr.
W0609 06:41:05.048375 139792217777920 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32
W0609 06:41:05.053524 139792217777920 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32
W0609 06:41:05.056409 139792226170624 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32

Python side effect
Python side effect
Python side effect
[<tf.Tensor: id=351, shape=(), dtype=int32, numpy=1>, <tf.Tensor: id=352, shape=(), dtype=int32, numpy=1>, <tf.Tensor: id=353, shape=(), dtype=int32, numpy=1>]


## 谨防Python状态

external_var = tf.Variable(0)


@tf.function
def buggy_consume_next(iterator):
    """Add ``next(iterator)`` to ``external_var`` and print the variable.

    NOTE(review): the ``assign_add`` line was missing from the extracted text.
    It is reconstructed here because the function otherwise never uses
    ``iterator``, which is what this cell demonstrates.
    """
    external_var.assign_add(next(iterator))
    tf.print('external_var:', external_var)


iterator = iter([0,1,2,3])
buggy_consume_next(iterator)
# The later calls do not advance the iterator; the printed value stays the same.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

external_var: 0
external_var: 0
external_var: 0


def measure_graph_size(f, *args):
    """Print how many nodes the graph traced for ``f(*args)`` contains.

    Args:
        f: Object exposing ``get_concrete_function`` and ``__name__``
            (here: a tf.function).
        *args: Arguments used to obtain the concrete function; also echoed in
            the printed message.
    """
    g = f.get_concrete_function(*args).graph
    print("{}({}) 的图中包含了 {} 个节点".format(
        f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
    """Accumulate ``|y - x|`` over every ``(x, y)`` pair in ``dataset``."""
    loss = tf.constant(0)
    for x, y in dataset:
        loss += tf.abs(y - x) # Some dummy computation.
    return loss


small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10
# Python lists: compare graph sizes for 2 vs 10 elements.
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

# The same data wrapped in tf.data.Dataset objects.
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))

train([(1, 1), (1, 1)]) 的图中包含了 8 个节点
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) 的图中包含了 32 个节点
train(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) 的图中包含了 4 个节点
train(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) 的图中包含了 4 个节点


## 自动控制依赖项

# Stateful operations execute automatically in program order.
a = tf.Variable(1.0)
b = tf.Variable(2.0)


@tf.function
def f(x, y):
    """Update ``a`` and then ``b``, and return ``a + b``."""
    a.assign(y * b)
    # NOTE(review): this line was missing from the extracted text; without it
    # the call below yields 6.0 rather than the 10.0 shown in the output.
    b.assign_add(x * a)
    return a + b


f(1.0, 2.0)

<tf.Tensor: id=739, shape=(), dtype=float32, numpy=10.0>


@tf.function
def f(x):
    """Create a ``tf.Variable`` inside the tf.function (rejected, see below)."""
    # The output below reports that a tf.function may not create variables on
    # a non-first call.
    v = tf.Variable(1.0)
    return v


with assert_raises(ValueError):
    f(1.0)

Caught expected exception
<class 'ValueError'>: in converted code:

<ipython-input-25-8e74447e7577>:4 f  *
v = tf.Variable(1.0)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:262 __call__
return cls._variable_v2_call(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
shape=shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:60 getter
return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:364 invalid_creator_scope
"tf.function-decorated function tried to create "

ValueError: tf.function-decorated function tried to create variables on non-first call.


v = tf.Variable(1.0)  # Create the variable outside the tf.function.


@tf.function
def f(x):
    """Add ``x`` to the module-level variable ``v`` and return the result.

    NOTE(review): the body was missing from the extracted text; it is
    reconstructed to match the 2.0 / 4.0 results printed below.
    """
    return v.assign_add(x)


print(f(1.0))  # 2.0
print(f(2.0))  # 4.0

tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)


class C:
    """Plain attribute holder for the lazily created variable."""


obj = C()
obj.v = None


@tf.function
def g(x):
    """Create ``obj.v`` if it is still ``None``, then add ``x`` to it.

    NOTE(review): the return line was missing from the extracted text; it is
    reconstructed to match the 2.0 / 4.0 results printed below.
    """
    if obj.v is None:
        obj.v = tf.Variable(1.0)
    return obj.v.assign_add(x)


print(g(1.0))  # 2.0
print(g(2.0))  # 4.0

tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)


state = []


@tf.function
def fn(x):
    """Return ``state[0] * x * state[1]``, creating both variables if needed.

    While ``state`` is empty, two variables are appended: the first is
    initialised from ``x``, the second from the first.
    """
    if not state:
        state.append(tf.Variable(2.0 * x))
        state.append(tf.Variable(state[0] * 3.0))
    return state[0] * x * state[1]


print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))

tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(36.0, shape=(), dtype=float32)


## 使用AutoGraph

tf.cond和tf.while_loop在tf.function中仍然可以使用，但是以命令式风格编写的带控制流的代码通常更容易编写和理解。

# A simple loop.
@tf.function
def f(x):
    """Apply ``tanh`` to ``x`` repeatedly until ``reduce_sum(x) <= 1``."""
    # The loop is written as a plain Python `while` on a tensor condition.
    while tf.reduce_sum(x) > 1:
        tf.print(x)
        x = tf.tanh(x)
    return x


f(tf.random.uniform([5]))

[0.829342961 0.858322263 0.900950909 0.851897 0.530384183]
[0.680123031 0.695392191 0.716760576 0.692059278 0.485674709]
[0.591599405 0.601434886 0.614898741 0.599303305 0.450776756]
[0.53104496 0.538069844 0.547566235 0.536553681 0.422537297]
[0.486179501 0.491525501 0.498693913 0.490374774 0.399065822]
[0.451178908 0.455426365 0.461089343 0.454513818 0.379149348]
[0.422867566 0.426349223 0.430971652 0.425602287 0.361968517]
[0.399343461 0.402265817 0.406133026 0.401639521 0.346946776]
[0.379387051 0.381885976 0.385184318 0.381350905 0.333665]
[0.362175018 0.36434418 0.367201209 0.363880038 0.321810097]
[0.347128421 0.349034756 0.351541221 0.348627061 0.311142713]
[0.333826423 0.335519224 0.337741673 0.335157365 0.30147627]
[0.321954757 0.323471278 0.325459719 0.323147237 0.292663]
[0.311273336 0.312642276 0.314435244 0.312349856 0.284584]
[0.301595032 0.302838922 0.304466605 0.302573323 0.277142316]
[0.292771578 0.293908447 0.295394808 0.293665737 0.270258158]
[0.284683794 0.285728157 0.287092626 0.285505235 0.263865024]
[0.277234435 0.278198302 0.279456645 0.277992576 0.257907033]
[0.270343572 0.271236718 0.272402078 0.271046132 0.25233686]
[0.263944477 0.264775217 0.265858531 0.264597982 0.247114092]
[0.257981181 0.258756459 0.259766966 0.258591145 0.242203966]
[0.252406299 0.253132015 0.254077554 0.252977312 0.237576365]
[0.24717927 0.247860536 0.248747766 0.247715324 0.233205199]
[0.242265314 0.242906466 0.24374117 0.242769822 0.229067564]
[0.237634286 0.238239139 0.239026278 0.238110229 0.225143358]
[0.233259991 0.233831868 0.234575793 0.233709976 0.221414775]
[0.229119495 0.229661271 0.230365857 0.229545817 0.217866093]
[0.225192651 0.22570689 0.22637549 0.225597292 0.214483246]
[0.221461684 0.221950635 0.222586185 0.221846417 0.211253688]
[0.217910782 0.218376443 0.218981609 0.218277216 0.208166167]
[0.214525893 0.214970052 0.215547174 0.214875415 0.205210552]
[0.211294428 0.211718708 0.212269917 0.211628318 0.202377662]
[0.208205134 0.208611 0.209138155 0.20852454 0.199659243]
[0.205247864 0.205636591 0.206141427 0.2055538 0.197047815]
[0.20241344 0.202786222 0.203270242 0.202706844 0.194536477]

<tf.Tensor: id=1006, shape=(5,), dtype=float32, numpy=
array([0.19969359, 0.2000515 , 0.2005161 , 0.19997531, 0.192119  ],
dtype=float32)>

# Show what `f` is bound to after decoration (a Function object, per the output below).
print(f)

<tensorflow.python.eager.def_function.Function object at 0x7f23e7df2240>


def f(x):
    """Undecorated version of the tanh loop, used to inspect AutoGraph output."""
    while tf.reduce_sum(x) > 1:
        tf.print(x)
        x = tf.tanh(x)
    return x


# Print the code AutoGraph generates for `f`.
print(tf.autograph.to_code(f))

def tf__f(x):
do_return = False
retval_ = ag__.UndefinedReturnValue()

def loop_test(x_1):
return ag__.converted_call('reduce_sum', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None) > 1

def loop_body(x_1):
ag__.converted_call('print', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None)
x_1 = ag__.converted_call('tanh', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None)
return x_1,
x, = ag__.while_stmt(loop_test, loop_body, (x,))
do_return = True
retval_ = x
cond = ag__.is_undefined_return(retval_)

def get_state():
return ()

def set_state(_):
pass

def if_true():
retval_ = None
return retval_

def if_false():
return retval_
retval_ = ag__.if_stmt(cond, if_true, if_false, get_state, set_state)
return retval_


AutoGraph：条件
AutoGraph会将if语句转换为等效的tf.cond调用。

# Test helper.
def test_tf_cond(f, *args):
    """Print whether the graph traced for ``f(*args)`` has a node named 'cond'.

    Args:
        f: Object exposing ``get_concrete_function`` and ``__name__``
            (here: a tf.function).
        *args: Arguments used to obtain the concrete function.
    """
    # Fetch the traced graph.
    g = f.get_concrete_function(*args).graph
    if any(node.name=='cond' for node in g.as_graph_def().node):
        print("{}({}) 使用 tf.cond.".format(
            f.__name__, ', '.join(map(str, args))))
    else:
        print("{}({}) 正常执行.".format(
            f.__name__, ', '.join(map(str, args))))


@tf.function
def hyperparam_cond(x, training=True):
    """Apply dropout (rate 0.5) to ``x`` when the Python flag ``training`` is set."""
    if training:
        x = tf.nn.dropout(x, rate=0.5)
    return x


@tf.function
def maybe_tensor_cond(x):
    """Return ``-x`` when ``x < 0``, otherwise ``x``."""
    if x < 0:
        x = -x
    return x


test_tf_cond(hyperparam_cond, tf.ones([1], dtype=tf.float32))
test_tf_cond(maybe_tensor_cond, tf.constant(-1)) # The condition is a tensor.
test_tf_cond(maybe_tensor_cond, -1)

hyperparam_cond(tf.Tensor([1.], shape=(1,), dtype=float32)) 正常执行.
maybe_tensor_cond(tf.Tensor(-1, shape=(), dtype=int32)) 使用 tf.cond.
maybe_tensor_cond(-1) 正常执行.


tf.cond有一些细微之处：

- 它的工作原理是追踪条件的两个分支，然后在运行时根据条件选择相应的分支；追踪两个分支可能导致Python代码被意外执行。
- 如果一个分支创建了下游要使用的张量，另一个分支也必须创建该张量。

@tf.function
def f():
    """Show that both branches of a tensor-conditioned ``if`` are traced."""
    x = tf.constant(0)
    if tf.constant(True):
        x = x + 1
        tf.print('执行，x：', x)
        print("Tracing then branch")
    else:
        x = x - 1
        tf.print('执行，x：', x)  # Not executed.
        print("Tracing else branch")  # Traced even though this branch never runs.
    return x


f()

Tracing then branch
Tracing else branch

<tf.Tensor: id=1128, shape=(), dtype=int32, numpy=1>


@tf.function
def f():
    """Define ``x`` in only one branch of a tensor-conditioned ``if``."""
    if tf.constant(True):
        x = tf.ones([3, 3])
    return x


# Both branches must define x; otherwise an exception is raised.
with assert_raises(ValueError):
    f()

Caught expected exception
<class 'ValueError'>: in converted code:

<ipython-input-40-c7af591027c1>:3 f  *
if tf.constant(True):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:439 if_stmt
return tf_if_stmt(cond, body, orelse, get_state, set_state)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:456 tf_if_stmt
outputs, final_state = control_flow_ops.cond(cond, body, orelse)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py:507 new_func
return func(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:1147 cond
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:86 cond_v2
op_return_value=pred)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:716 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:486 wrapper
outputs = func()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:512 wrapper
tuple(s.symbol_name for s in undefined)))

ValueError: The following symbols must also be initialized in the else branch: ('x',). Alternatively, you may initialize them before the if statement.


AutoGraph和循环
AutoGraph有一些简单的转换循环规则。

• for：如果iterable是张量，则转换
• while：如果while条件取决于张量，则转换

# Test helper.
def test_dynamically_unrolled(f, *args):
    """Print how the loop in ``f`` was lowered, judged by graph node names.

    Looks for a node named 'while' first, then 'ReduceDataset'; if neither is
    present the loop is reported as unrolled.

    Args:
        f: Object exposing ``get_concrete_function`` and ``__name__``
            (here: a tf.function).
        *args: Arguments used to obtain the concrete function.
    """
    g = f.get_concrete_function(*args).graph
    if any(node.name == 'while' for node in g.as_graph_def().node):
        print("{}({}) uses tf.while_loop.".format(
            f.__name__, ', '.join(map(str, args))))
    elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):
        print("{}({}) uses tf.data.Dataset.reduce.".format(
            f.__name__, ', '.join(map(str, args))))
    else:
        print("{}({}) gets unrolled.".format(
            f.__name__, ', '.join(map(str, args))))

@tf.function
def for_in_range():
    """Sum 0..4 by iterating over a Python ``range``."""
    x = 0
    for i in range(5):
        x += i
    return x


@tf.function
def for_in_tfrange():
    """Sum 0..4 by iterating over ``tf.range``."""
    x = tf.constant(0, dtype=tf.int32)
    for i in tf.range(5):  # The iterable is a tensor.
        x += i
    return x


@tf.function
def for_in_tfdataset():
    """Sum 0..4 by iterating over ``tf.data.Dataset.range``."""
    x = tf.constant(0, dtype=tf.int64)
    for i in tf.data.Dataset.range(5):
        x += i
    return x


test_dynamically_unrolled(for_in_range)
test_dynamically_unrolled(for_in_tfrange)
test_dynamically_unrolled(for_in_tfdataset)

for_in_range() gets unrolled.
for_in_tfrange() uses tf.while_loop.
for_in_tfdataset() uses tf.data.Dataset.reduce.

@tf.function
def while_py_cond():
    """Count down from a Python int with a Python ``while`` condition."""
    x = 5
    while x > 0:
        x -= 1
    return x


@tf.function
def while_tf_cond():
    """Count down from a tensor, so the ``while`` condition is a tensor."""
    x = tf.constant(5)
    while x > 0:   # x in the condition is a tensor.
        x -= 1
    return x


test_dynamically_unrolled(while_py_cond)
test_dynamically_unrolled(while_tf_cond)

while_py_cond() gets unrolled.
while_tf_cond() uses tf.while_loop.


@tf.function
def buggy_while_py_true_tf_break(x):
    """Python ``while True`` with a tensor-conditioned ``break`` (fails)."""
    while True:
        if tf.equal(x, 0):
            break
        x -= 1
    return x


@tf.function
def while_tf_true_tf_break(x):
    """Same loop, but with a tensor as the top-level ``while`` condition."""
    while tf.constant(True):  # With a break, the top-level condition must be a tensor.
        if tf.equal(x, 0):
            break
        x -= 1
    return x


with assert_raises(TypeError):
    test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)
# Outside the `with`: this call is expected to succeed.
test_dynamically_unrolled(while_tf_true_tf_break, 5)

Caught expected exception
<class 'TypeError'>: in converted code:

<ipython-input-45-f42fe93cfd97>:3 buggy_while_py_true_tf_break  *
while True:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:313 while_stmt
return _py_while_stmt(test, body, init_state, opts)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:401 _py_while_stmt
while test(*state):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:698 __bool__
raise TypeError("Using a tf.Tensor as a Python bool is not allowed. "

TypeError: Using a tf.Tensor as a Python bool is not allowed. Use if t is not None: instead of if t: to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

while_tf_true_tf_break(5) uses tf.while_loop.

@tf.function
def buggy_py_for_tf_break():
    """Python ``range`` loop with a tensor-conditioned ``break`` (fails)."""
    x = 0
    for i in range(5):
        if tf.equal(i, 3):
            break
        x += i
    return x


@tf.function
def tf_for_tf_break():
    """Same loop, but iterating over ``tf.range``."""
    x = 0
    for i in tf.range(5):  # With a break, the top-level iterable must be a tensor.
        if tf.equal(i, 3):
            break
        x += i
    return x


with assert_raises(TypeError):
    test_dynamically_unrolled(buggy_py_for_tf_break)
# Outside the `with`: this call is expected to succeed.
test_dynamically_unrolled(tf_for_tf_break)


Caught expected exception
<class 'TypeError'>: in converted code:

<ipython-input-46-902b45f3c32e>:4 buggy_py_for_tf_break  *
for i in range(5):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:110 for_stmt
return _py_for_stmt(iter_, extra_test, body, init_state)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:117 _py_for_stmt
if extra_test is not None and not extra_test(*state):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:698 __bool__
raise TypeError("Using a tf.Tensor as a Python bool is not allowed. "

TypeError: Using a tf.Tensor as a Python bool is not allowed. Use if t is not None: instead of if t: to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

tf_for_tf_break() uses tf.while_loop.


# Implement a dynamic RNN.
batch_size = 32
seq_len = 3
feature_size=4


# One RNN step: add the inputs to the running state.
def rnn_step(inputs, state):
    """Return ``inputs + state``."""
    return inputs + state


@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
    """Run ``rnn_step`` over the time axis and return the state at every step.

    Args:
        rnn_step: Callable ``(inputs, state) -> new_state``.
        input_data: Tensor laid out as [batch, time, features].
        initial_state: State passed to the first step.

    Returns:
        The stacked per-step states, transposed back to batch-first.
    """
    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])  # Each time step is fed the whole batch.
    max_seq_len = input_data.shape[0]

    # States accumulated inside the loop must be kept in a tf.TensorArray.
    states = tf.TensorArray(tf.float32, size=max_seq_len)
    state = initial_state
    # Iterate over the time steps.
    for i in tf.range(max_seq_len):
        state = rnn_step(input_data[i], state)
        states = states.write(i, state)
    # Move the batch dimension back to the front.
    return tf.transpose(states.stack(), [1, 0, 2])


dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))


<tf.Tensor: id=1998, shape=(32, 3, 4), dtype=float32, numpy=
array([[[0.42647886, 0.73600817, 0.10211909, 0.89989746],
[0.772506  , 1.6853498 , 0.48793948, 1.4499462 ],
[1.1096102 , 2.3388233 , 0.5920907 , 1.588302  ]],

...
[[0.15579033, 0.4594922 , 0.17970431, 0.19183934],
[0.19597077, 0.5362154 , 0.19988954, 0.38290274],
[0.7524748 , 1.0519221 , 0.76595306, 0.5257962 ]]], dtype=float32)>


@tf.function
def buggy_loop_var_uninitialized():
    """Assign ``x`` only inside a ``tf.range`` loop (fails)."""
    for i in tf.range(3):
        x = i  # x must already be initialised above the loop.
    return x


@tf.function
def f():
    """Same loop, with ``x`` initialised before it."""
    x = tf.constant(0)
    for i in tf.range(3):
        x = i
    return x


with assert_raises(ValueError):
    buggy_loop_var_uninitialized()
# Outside the `with`: this call is expected to succeed.
f()

Caught expected exception
<class 'ValueError'>: in converted code:

<ipython-input-53-05437e37672a>:3 buggy_loop_var_uninitialized  *
for i in tf.range(3):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:95 for_stmt
return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:125 _known_len_tf_for_stmt
_disallow_undefs_into_loop(*init_state)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:50 _disallow_undefs_into_loop
tuple(s.symbol_name for s in undefined)))

ValueError: TensorFlow requires that the following symbols must be defined before the loop: ('x',)

<tf.Tensor: id=2062, shape=(), dtype=int32, numpy=2>


@tf.function
def buggy_loop_type_changes():
    """Rebind a float32 loop variable to int32 values inside the loop (fails)."""
    x = tf.constant(0, dtype=tf.float32)
    for i in tf.range(3): # Yields tensors of type tf.int32...
        x = i
    return x


with assert_raises(tf.errors.InvalidArgumentError):
    buggy_loop_type_changes()

Caught expected exception
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Input 1 of node while/merge/_10 was passed int32 from while/next_iteration/_28:0 incompatible with expected float. [Op:__inference_buggy_loop_type_changes_2119]


@tf.function
def buggy_concat():
    """Grow ``x`` by one row per iteration inside a ``tf.range`` loop (fails)."""
    x = tf.ones([0, 10])
    for i in tf.range(5):
        x = tf.concat([x, tf.ones([1, 10])], axis=0)  # A loop variable's shape must not change.
    return x


with assert_raises(ValueError):
    buggy_concat()


# NOTE(review): the `def` line and the call were missing from the extracted
# text; they are reconstructed to match the 5x10 tensor of ones printed below.
@tf.function
def concat_with_padding():
    """Fill a [5, 10] tensor row by row while keeping its shape fixed."""
    x = tf.zeros([5, 10])
    for i in tf.range(5):
        # Rows before i are kept, row i becomes ones, the rest is zero padding.
        x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)
        x.set_shape([5, 10])
    return x


concat_with_padding()


Caught expected exception
<class 'ValueError'>: in converted code:

<ipython-input-55-74d839116efa>:4 buggy_concat  *
for i in tf.range(5):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:95 for_stmt
return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:156 _known_len_tf_for_stmt
opts=dict(maximum_iterations=n))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:327 _tf_while_stmt
retval = control_flow_ops.while_loop(test, body, init_state, **opts)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2646 while_loop
return_same_structure=return_same_structure)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:213 while_loop
len_orig_loop_vars], expand_composites=True))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:869 _check_shapes_compat
"specify a less-specific shape." % (input_t.name, shape, t.shape))

ValueError: Input tensor 'ones:0' enters the loop with shape (0, 10), but has shape (1, 10) after one iteration. To allow the shape to vary across iterations, use the shape_invariants argument of tf.while_loop to specify a less-specific shape.

<tf.Tensor: id=2240, shape=(5, 10), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>

03-16 3186

03-26 7438
02-12 1万+
06-19 2万+
11-07 350