在Keras中使用TPU
从Tensorflow 2.1开始,通过Keras API支持TPU。Keras支持适用于TPU和TPU盒。这是一个适用于TPU,GPU和CPU的示例:
# TPU detection
try:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
tpu = None
# TPUStrategy for distributed training
if tpu:
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)
else: # default strategy that works on CPU and single GPU
strategy = tf.distribute.get_strategy()
# use TPUStrategy scope to define model
with strategy.scope():
model = tf.keras.Sequential( ... )
model. compile( ... )
# train model normally on a tf.data.Dataset
model.fit(training_dataset, epochs=EPOCHS, steps_per_epoch=...)
在此代码段中:
尽管有多种方法可以在Tensorflow模型中加载数据,但对于TPU,需要使用tf.data.DatasetAPI。
TPU速度非常快,并且在它们上运行时,提取数据通常会成为瓶颈。
其他的和tensorflow和keras都一样,不过只是模型创建在scope下不同而已。
接下来开始享受tpu飞一般的速度吧