Saving and loading a model with a distribution strategy

This tutorial shows how to save and load models with the SavedModel APIs when using tf.distribute.Strategy. The high-level APIs are model.save and tf.keras.models.load_model; the low-level APIs are tf.saved_model.save and tf.saved_model.load. For models that do not have clearly defined inputs, the low-level APIs should be used. Which saving and loading approach you choose depends on your needs.

Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please see the saved model guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds
import tensorflow as tf

tfds.disable_progress_bar()

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])

    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 4s 5ms/step - loss: 0.1971 - accuracy: 0.9421
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0662 - accuracy: 0.9801

Save and load the model

Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available:

  • High-level Keras APIs: model.save and tf.keras.models.load_model

  • Low-level APIs: tf.saved_model.save and tf.saved_model.load

Keras API

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)  # save() should be called out of strategy scope
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets


Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0480 - accuracy: 0.0990
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0334 - accuracy: 0.0989

After restoring the model, you can continue training on it, without needing to call compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to the saved_model format.
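If you want to check what the export produced on disk, you can simply list the save directory. This is only a quick sketch; the exact file list can vary slightly by TensorFlow version, but a SavedModel export typically contains a saved_model.pb file plus variables/ and assets/ subdirectories:

import os

# Inspect the SavedModel directory written by model.save(keras_model_path).
print(sorted(os.listdir(keras_model_path)))
# Typically prints something like: ['assets', 'saved_model.pb', 'variables']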

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0481 - accuracy: 0.0989
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0329 - accuracy: 0.0990

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same as the one used before saving.

tf.saved_model API

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras APIs:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

It can be loaded with tf.saved_model.load(). However, since this is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions which can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. You can first list the available keys (see the sketch below) and then do inference with the chosen function.
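A minimal sketch, reusing the `loaded` object from above, to see which signatures the export actually exposes:

# List the signature keys of the SavedModel; for a saved Keras model this
# typically includes 'serving_default'.
print(list(loaded.signatures.keys()))

To do inference with the chosen function: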

predict_dataset = eval_dataset.map(lambda image, label: image)

for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': array([[ 0.17218862,  0.07492599, -0.0548683 ,  0.03503785, -0.03743191,
        -0.05301537,  0.01267872, -0.02870197, -0.33800656,  0.17991678],
       [ 0.12937182, -0.21557797, -0.09474514,  0.39076763, -0.22147779,
        -0.1787742 ,  0.2154337 ,  0.00788027, -0.14960325,  0.43123117],
       ...
       [-0.02391969, -0.10441571, -0.00624696,  0.06563509, -0.14965585,
        -0.3743796 ,  0.0422266 ,  0.04684277,  0.05023851, -0.07264638]],
      dtype=float32)>}

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.

Calling the restored function is just a forward pass (prediction) on the saved model. What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap this loaded object into a Keras layer. Luckily, TF Hub provides hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.2059 - accuracy: 0.9393
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0681 - accuracy: 0.9799

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.
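If you only want to reuse the loaded model as a frozen feature extractor (a common transfer-learning setup), you can disable training of the wrapped layer instead. A minimal sketch, assuming the same `loaded` object as above:

import tensorflow_hub as hub

# Freeze the wrapped SavedModel so only the layers you add on top of it get trained.
frozen_layer = hub.KerasLayer(loaded, trainable=False)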

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use Keras's model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model back, then use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can only get a Keras model back if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save, and load a non-Keras model with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs on both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, just check whether model.inputs is None. If it is not None, you are all set. Input shapes are automatically defined when the model is used in .fit, .evaluate, .predict, or when calling the model directly (model(inputs)); a short sketch of this check follows.
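A minimal sketch of this check, using a hypothetical Sequential model (not part of this tutorial's code) that starts without a defined input shape; the exact behavior of model.inputs may vary slightly across TensorFlow versions:

# A Sequential model built without an input shape: inputs are not yet defined,
# so Keras's model.save() would fail at this point.
seq = tf.keras.Sequential([tf.keras.layers.Dense(3)])
print(seq.inputs)        # None -- no well-defined inputs yet

# Calling the model on concrete data builds it and defines the input shape.
_ = seq(tf.ones((1, 4)))
print(seq.inputs)        # now defined (an input tensor of shape (None, 4))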

Here is an example with a subclassed model:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR!
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at ...>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Done!

