首先檢測TPU存在:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver() #如果先前設置好了TPU_NAME環境變量,不需要再給參數.
tpu的返回值為1 or 0 ,1則檢測到了TPU.
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)
with strategy.scope():
#define a model
#compile it
#train it
因為這目前還是一個實驗功能,代碼實現可能過一段時間就變了,看官方給的通知吧.
