面对升级后的人工智能学习框架 TensorFlow,应如何将代码升级到新版本中?承接上篇文章,本文将继续手把手指引您的 TensorFlow 代码升级过程,让您从 TensorFlow 2 的简洁和便利中受益。
↓点击观看代码迁移视频教程↓
如何将代码迁移至 TensorFlow 2
保存和加载
检查点兼容性
TensorFlow 2.0 使用基于对象的检查点。
如果您足够小心,旧版的基于名称的检查点仍然可以加载。如果进行代码转换,可能需要对变量名进行变更,但是也有变通的办法。
最简单的方法是将新模型中的名字和检查点的名字保持一致:
-
您依旧可以为所有变量设置
name
参数。 -
Keras 模型也有一个
name
参数,模型将其设置成自有变量的前缀。 -
v1.name_scope
函数可以用于设置变量名前缀。这与tf.variable_scope
有极大的不同。因其只影响名称,不追踪变量和重用。
如果无法适用于您的用例,您可以尝试 v1.train.init_from_checkpoint
函数。该函数带有一个 assignment_map
参数,可以指定从旧名称到新名称的映射。
请注意:
基于对象的检查点可以延迟加载,但基于名称的检查点不能,要求在构建函数时调用所有变量。一些模型只有当您调用build
或在一个批次的数据上运行模型时才会构建变量。
TensorFlow Estimator 代码库包含一个转换工具,可以将预制 Estimator 的检查点从 TensorFlow 1.X 升级到 2.0。因此,可以作为如何为类似用例构建工具的示例。
已保存模型的兼容性
已保存模型没有太大的兼容性问题。
-
TensorFlow 1.x saved_models 可以在 TensorFlow 2.0 中使用。
-
如果 TensorFlow 2.0 能够支持 TensorFlow 1.x 的所有算子,TensorFlow 2.0 saved_models 甚至可以加载 TensorFlow 1.x 中的工作。
Graph.pb 或 Graph.pbtxt
没有直接的方法将原始的Graph.pb
文件升级到 TensorFlow 2.0。您最好的选择是升级生成文件的代码。
但如果您有一个 Frozen graph(tf.Graph
的一种,其中的变量变为常量),那么可以使用v1.wrap_function
将其转换为concrete_function
:
def wrap_frozen_graph(graph_def, inputs, outputs):
def _imports_graph_def():
tf.compat.v1.import_graph_def(graph_def, name="")
wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
import_graph = wrapped_import.graph
return wrapped_import.prune(
tf.nest.map_structure(import_graph.as_graph_element, inputs),
tf.nest.map_structure(import_graph.as_graph_element, outputs))
此例中有个 Frozen graph:
path = tf.keras.utils.get_file(
'inception_v1_2016_08_28_frozen.pb',
'http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz',
untar=True)
Downloading data from http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz
24698880/24695710 [==============================] - 2s 0us/step
加载 tf.GraphDef
:
graph_def = tf.compat.v1.GraphDef()
loaded = graph_def.ParseFromString(open(path,'rb').read())
打包至 concrete_function
:
inception_func = wrap_frozen_graph(
graph_def, inputs='input:0',
outputs='InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu:0')
向其传递张量,作为输入:
input_img = tf.ones([1,224,224,3], dtype=tf.float32)
inception_func(input_img).shape
TensorShape([1, 28, 28, 96])
Estimator
使用 Estimator 进行训练
TensorFlow 2.0 支持 Estimator。
当您使用 Estimator 时,可以使用 TensorFlow 1.x 中的input_fn()
、tf.estimator.TrainSpec
和tf.estimator.EvalSpec
。
下面是结合训练和评估规格使用input_fn
的一个例子。
创建 input_fn 以及训练 / 评估规格:
# Define the estimator's input_fn
def input_fn():
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
BUFFER_SIZE = 10000
BATCH_SIZE = 64
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label[..., tf.newaxis]
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
return train_data.repeat()
# Define train & eval specs
train_spec = tf.estimator.TrainSpec(input_fn=input_fn,
max_steps=STEPS_PER_EPOCH * NUM_EPOCHS)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn,
steps=STEPS_PER_EPOCH)
使用 Keras 模型定义
TensorFlow 2.0 中构建 Estimator 的方法会有一些不同。
我们建议您使用 Keras 定义自己的模型,然后使用 tf.keras.estimator.model_to_estimator
实用程序将模型变为 Estimator。下面的代码展示创建和训练 Estimator 时如何使用这个实用工具。
def make_model():
return tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
model = make_model()
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
estimator = tf.keras.estimator.model_to_estimator(
keras_model = model
)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp4q8g11bh
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp4q8g11bh
INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1635: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1635: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4q8g11bh', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
}
}
, '_keep_checkpoint_max': 5,