联邦学习的修仙之路_3

目录

4. TensorFlow Federated 实现Minst数据集识别

4.1  导入并熟悉Minst数据集

4.2 处理Mnist数据集

4.3 准备/生成Model(Keras)

4.4 开始训练

4.5 Evaluation

4.6 结语


4. TensorFlow Federated 实现Minst数据集识别

前两篇博客回顾 重新学习 了tensorflow 结构以及 tff 的两层API,讲道理官网对这两层的讲解真的有够迷。我之前一直纠结这个‘两层’的含义,现在稍微懂了一点:其中高层的FLearning是指不关注模型底层架构,把现有的模型拿来之后做简单的转换;而FCore则指可以重新创建或修改底层模型架构的API。

这篇博文的出发点是官方文档给的第一个TFF案例,(我学代码比较习惯从直接看一篇跑通的代码开始,不然一直学理论也搞不出个所以然来)。不得不说,这篇官方tutorial虽然是英文文档,但讲的非常清晰(比前几篇指南好了N个层次),相信只要不是英语过于苦手的朋友都可以读的懂这篇官方文档(shown as below):

https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification

写在最前面:官方给的例子主体都是由FLearning的高层API撰写的(所以在没有特殊标注之前,也基于官网同步)。

4.1  导入并熟悉Minst数据集

和常规的机器学习步骤相同,这里的第一部也需要进行数据集导入,处理:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

这里使用内置数据集Mnist进行导入,导入之后可以分别使用如下函数进行查看长度,类型及打印(第一个图像),具体代码原文中展示十分详细,在这里不予赘述;仅列出几个关键字及作用。

element_type_structureAttributes, The element type information of the client datasets. elements returned by datasets in this ClientData object.

create_tf_dataset_for_client

Method,
create_tf_dataset_for_client(
   client_id: str
) 

-> tf.data.Dataset
from matplotlib import pyplot as plt plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal') plt.grid(False) _ = plt.show()

用于画图的固定套组,老朋友了。

官方文档里使用(包括但不限于)上述关键字分别查看了数据的数量,数据的标签和实际图像(label, pixels)但这些内容都属于帮助我们熟悉数据集,真正需要用于tff过程的只有最上面导入数据集的那一句代码。

在此之后,原文还探索了每个用户的书写特征(根据id并统计出条形图)并计算mean产生图像(这里证明了联邦学习所处在的非独立同分布情况)因为这部分也属于探索数据集,就不再分析了;可以查原文,讲的比较清晰。

4.2 处理Mnist数据集

在导入数据集之后,按正常流程将数据集进行处理数据集:将其拉平,重复,打乱。值得注意的是,这里将处理数据的过程(囊括到一个方法里),然后通过调用传参进行调用。

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

这里可以看到,前面一段定义了一些超参数;第二段实现图像像素拉平;第三段实现数据的重复,打乱;这部分代码可以用如下代码进行检验:

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch

这里啰嗦一句next和iter;list、tuple等都是可迭代对象,我们可以通过iter()函数获取这些可迭代对象的迭代器。然后我们可以对获取到的迭代器不断使⽤next()函数来获取下⼀条数据。

在准备好这些之后,分出是个客户端并为这十个客户端分配数据集:

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

4.3 准备/生成Model(Keras)

首先利用Keras构建模型:

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

结构比较简单,输入层 + 稠密层 + 激活层。构造好模型之后将其转换 tff 实例,这步也是高阶API的精髓:

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

4.4 开始训练

在构建好模型之后,可以开始进行训练。这里使用了build_federated_averaging_process来创建了一个交互的训练过程。这里定义了两个学习率:客户端和服务器,其中前者负责本地的更新而后者作用于avg。

可以用 type_signature 来打印出函数签名(),(函数签名由函数原型组成。它告诉你的是关于函数的一般信息,它的名称,参数,它的范围以及其他杂项信息。可以确定传入的参数是符合要求的)

接着将iterative_process进行initialize得到state,即:

state = iterative_process.initialize()

值得注意的是这里的state并不是指 ‘状态’,根据官方的解释:The initialize_fn function must return an object which is expected as input to and returned by the next_fn function. By convention, we refer to this object as state.

再用两个参数去接 state 和 metric 就可以开始训练优化过程:

state, metrics = iterative_process.next(state, federated_train_data)

再使用for循环进行循环训练优化,就可以了:

for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))

4.5 Evaluation

在进行了一定轮次的训练之后(或者在准确度达到一定的程度之后)可以停止训练并开始进行模型的评估。原文说也是在防止过拟合。直接调用并创建得到实例:

evaluation = tff.learning.build_federated_evaluation(MnistModel) 

在得到实例之后,得到 train_metrics,再str()打印即可;后续如果想测试准确度,再重复使用evaluation即可。

train_metrics = evaluation(state.model, federated_train_data)

str(train_metrics)

4.6 结语

这篇博客顺理了一遍FL高阶API的Mnist的识别;可以发现在这个项目里tff的东西其实比较少,大部分都是keras的东西还有数据处理。这里可能也就体现出来了高阶API的特点。下一篇准备搞一下低阶API,因为最开始也是从低阶开始学这个项目的,毕竟是原汁原味的TFF 不是 \( ̄▽ ̄)/

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值