# Simple FedAvg training on the EMNIST dataset.
import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
# Quick sanity check that TFF is installed correctly:
# print(tff.federated_computation(lambda: 'Hello World')())

# Load the federated EMNIST dataset (digits only), cached on local disk.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    cache_dir='/home/cqx/PycharmProjects/cache/fed_emnist_digitsonly')

# Inspect the dataset: number of clients and the per-example structure.
print(len(emnist_train.client_ids))
print(emnist_train.element_type_structure)

# Build the dataset for one specific client; returns a `tf.data.Dataset` object.
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

# iter() wraps the dataset in an iterator; next() returns successive elements
# until StopIteration is raised when the data is exhausted.
example_element = next(iter(example_dataset))
print(example_element['label'].numpy())
# Preprocessing plan (see `preprocess` below): flatten the images into flat
# arrays, shuffle the individual examples, organize them into batches, and
# rename the features.

# Number of clients participating in each training round.
NUM_CLIENTS = 10
# Number of local passes (epochs) over each client's data per round.
NUM_EPOCHS = 5
# Local mini-batch size.
BATCH_SIZE = 20
# Buffer size used when randomly shuffling examples.
SHUFFLE_BUFFER = 100
# Number of batches to prefetch in the input pipeline.
PREFETCH_BUFFER = 10
def preprocess(dataset):
    """Repeat, shuffle, batch, reformat, and prefetch one client's dataset.

    The returned `tf.data.Dataset` yields `OrderedDict`s with keys `x`
    (batch of flattened 784-pixel images) and `y` (batch of labels,
    shape [-1, 1]).
    """

    def to_xy(batch):
        """Flatten `pixels` and rename the features as an `OrderedDict`."""
        flat_pixels = tf.reshape(batch['pixels'], [-1, 784])
        labels = tf.reshape(batch['label'], [-1, 1])
        return collections.OrderedDict(x=flat_pixels, y=labels)

    # repeat(count) replays the data `count` times (one pass per local epoch);
    # shuffle(buffer, seed) randomizes example order; prefetch overlaps
    # producer and consumer work.
    pipeline = dataset.repeat(NUM_EPOCHS)
    pipeline = pipeline.shuffle(SHUFFLE_BUFFER, seed=1)
    pipeline = pipeline.batch(BATCH_SIZE)
    pipeline = pipeline.map(to_xy)
    return pipeline.prefetch(PREFETCH_BUFFER)
preprocessed_example_dataset = preprocess(example_dataset)
# tf.nest.map_structure applies the lambda to every leaf of a nested
# structure, e.g.:
#   a = [24, 76, "ab"]
#   tf.nest.map_structure(lambda p: p * 2, a)  ->  [48, 152, 'abab']
# Here it converts every tensor in the first batch to a NumPy array.
sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))
print(len(sample_batch['y']))
print(sample_batch)
# Build the per-client inputs for one round of training or evaluation.
def make_federated_data(client_data, client_ids):
    """Return one preprocessed `tf.data.Dataset` per client id."""

    def client_dataset(cid):
        # Materialize the raw client dataset, then apply the shared pipeline.
        return preprocess(client_data.create_tf_dataset_for_client(cid))

    return [client_dataset(cid) for cid in client_ids]
# Construct the client training data: take a fixed cohort of the first
# NUM_CLIENTS client ids and build one preprocessed dataset per client.
sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]
federated_train_data = make_federated_data(emnist_train, sample_clients)
print(f'Number of client datasets: {len(federated_train_data)}')
print(f'First dataset: {federated_train_data[0]}')
# Define the network model.
def create_keras_model():
    """A single dense softmax classifier over flattened 28x28 images."""
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.InputLayer(input_shape=(784,)))
    # Zero-initialized weights: every run starts from the same point.
    model.add(tf.keras.layers.Dense(10, kernel_initializer='zeros'))
    model.add(tf.keras.layers.Softmax())
    return model
# To use any model with TFF it must be wrapped in an instance of the
# `tff.learning.Model` interface; the wrapper takes the Keras model plus an
# input spec taken from an example batch.
def model_fn():
    """Construct a fresh Keras model and wrap it for use with TFF."""
    keras_model = create_keras_model()
    input_spec = preprocessed_example_dataset.element_spec
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=input_spec,
        loss=loss,
        metrics=metrics)
# Federated Averaging (FedAvg).
# The client optimizer is only used to compute the local model updates on
# each client; the server optimizer applies the averaged update to the
# global model.
training_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

# Show the type signature of the server-side FedAvg `initialize` computation.
print('可视化服务器上的FedAVG进程')
print(training_process.initialize.type_signature.formatted_representation())

# Initialize the server state.
train_state = training_process.initialize()

# Each call to `next`:
#   - sends the server state (including model parameters) to the clients,
#   - trains locally on each client's on-device data,
#   - collects and averages the model updates, and
#   - applies the new update on the server to produce the new global model.

# Single training round (kept for reference):
# result = training_process.next(train_state, federated_train_data)
# train_state = result.state
# train_metrics = result.metrics
# print('round 1, metrics={}'.format(train_metrics))

# Train for multiple rounds.
NUM_ROUNDS = 11
for round_num in range(1, NUM_ROUNDS):
    result = training_process.next(train_state, federated_train_data)
    train_state = result.state
    train_metrics = result.metrics
    print('round {:2d}, metrics={}'.format(round_num, train_metrics))
# ---------------------------------------------------------------------------
# Sample console output from one run (reference only; not executable code):
#
# 3383
# OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
# 1
# 20
# OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
#        [1., 1., 1., ..., 1., 1., 1.],
#        [1., 1., 1., ..., 1., 1., 1.],
#        ...,
#        [1., 1., 1., ..., 1., 1., 1.],
#        [1., 1., 1., ..., 1., 1., 1.],
#        [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
#        [1],
#        [5],
#        [7],
#        [1],
#        [7],
#        [7],
#        [1],
#        [4],
#        [7],
#        [4],
#        [2],
#        [2],
#        [5],
#        [4],
#        [1],
#        [1],
#        [0],
#        [0],
#        [9]]))])
# Number of client datasets: 10
# First dataset: <PrefetchDataset element_spec=OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))])>
# 可视化服务器上的FedAVG进程
# ( -> <
#   global_model_weights=<
#     trainable=<
#       float32[784,10],
#       float32[10]
#     >,
#     non_trainable=<>
#   >,
#   distributor=<>,
#   client_work=<>,
#   aggregator=<
#     value_sum_process=<>,
#     weight_sum_process=<>
#   >,
#   finalizer=<
#     int64
#   >
# >@SERVER)
# round 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12345679), ('loss', 3.1193738), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
# round 2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.13518518), ('loss', 2.9834726), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
# round 3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.14382716), ('loss', 2.8616652), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
# round 4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.17407407), ('loss', 2.7957022), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
# round 5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.19917695), ('loss', 2.6146567), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
# round 6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.21975309), ('loss', 2.5297604), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
# round 7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.2409465), ('loss', 2.4053502), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
# round 8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.2611111), ('loss', 2.3153887), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
# round 9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.30823046), ('loss', 2.1240258), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
# round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.33312756), ('loss', 2.1164267), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
#
# Process finished with exit code 0