https://www.tensorflow.org/federated/
-
Federated Learning (FL) API
該層提供了一組高階接口,使開發者能夠將包含的聯合訓練和評估實現應用於現有的 TensorFlow 模型。 -
Federated Core (FC) API
該系統的核心是一組較低階接口,可以通過在強類型函數式編程環境中結合使用 TensorFlow 與分布式通信運算符,簡潔地表達新的聯合算法。這一層也是我們構建聯合學習的基礎。 -
借助 TFF,開發者能夠以聲明方式表達聯合計算,從而將它們部署到不同的運行時環境中。TFF 包含一個用於實驗的單機模擬運行時。請訪問相關 教程,並親自試用!
from six.moves import range
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_federated.python.examples import mnist
tf.compat.v1.enable_v2_behavior()
# Load simulation data.
source, _ = tff.simulation.datasets.emnist.load_data()
def client_data(n):
dataset = source.create_tf_dataset_for_client(source.client_ids[n])
return mnist.keras_dataset_from_emnist(dataset).repeat(10).batch(20)
# Pick a subset of client devices to participate in training.
train_data = [client_data(n) for n in range(3)]
# Grab a single batch of data so that TFF knows what data looks like.
sample_batch = tf.nest.map_structure(
lambda x: x.numpy(), iter(train_data[0]).next())
# Wrap a Keras model for use with TFF.
def model_fn():
return tff.learning.from_compiled_keras_model(
mnist.create_simple_keras_model(), sample_batch)
# Simulate a few rounds of training with the selected client devices.
trainer = tff.learning.build_federated_averaging_process(model_fn)
state = trainer.initialize()
for _ in range(5):
state, metrics = trainer.next(state, train_data)
print (metrics.loss)