TensorFlow 2.0 默认是即刻执行模式。相比以前要先构建模型结构图,执行后才看到结果。该模式下,构建图的同时就能输出对应结果,简化了调试流程。输入时也可以直接输入数据,不用先构造变量。不过,经初步测试,该模式会降低执行效率,推荐只在调试时使用。
可通过在方法前加 @tf.function 标识切换到JIT编译模式,该模式是1.0的默认模式,具有高效率,用于生产环境。
下面是测试代码:
import tensorflow as tf
import numpy as np
import time
import shutil
import os
class MyModel(tf.keras.Model):
def __init__(self,units):
super(MyModel, self).__init__(self)
self.dense=tf.keras.layers.Dense(units,activation=None)
def call(self, input_data):
# print('input_data',input_data)
output = self.dense(input_data)
return output
print(tf.__version__)
# 定义模型
my_model1=MyModel(3)
my_model2=MyModel(1)
losses = tf.keras.losses.MeanAbsoluteError()
optimizer = tf.keras.optimizers.Adadelta(learning_rate=1)
# 使用 @tf.function 标识,进行JIT编译,执行效率高
# 去掉 @tf.function 标识为即刻模式,可用于调试,执行效率较低
@tf.function
def train(input_data,target_data):
with tf.GradientTape() as tape:
# print('input_data',input_data.shape)
prediction = my_model1(input_data)
prediction = my_model2(prediction)
# tf.print('prediction1', prediction)
loss = losses(prediction, target_data)
# 记录日志,会影响效率
tf.summary.scalar('loss', loss, step=optimizer.iterations)
variables = my_model1.trainable_variables + my_model2.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
# 记录日志,会影响效率
if os.path.exists('./tmp/summaries'):
shutil.rmtree('./tmp/summaries')
summary_writer = tf.summary.create_file_writer('./tmp/summaries')
# 打印执行时间
start = time.process_time()
with summary_writer.as_default():
for i in range(500):
input_data = np.array([[1.,1.]])
target_data = np.array([[2.]])
train(input_data,target_data)
elapsed = (time.process_time() - start)
print("运行时间:",elapsed)
# 保存模型权重
my_model1.save_weights('./tmp/save_models1.h5')
my_model2.save_weights('./tmp/save_models2.h5')
# 加载模型权重
my_model1.load_weights('./tmp/save_models1.h5')
my_model2.load_weights('./tmp/save_models2.h5')
# 下面不在 @tf.function 标识的方法内,执行为即刻模式
print('识别')
input_data = np.array([[1.,1.]])
prediction = my_model1(input_data)
prediction = my_model2(prediction)
print('prediction2', prediction.numpy())