tff.simulation.datasets.emnist.load_data加载本地数据集

tff.simulation.datasets.emnist.load_data加载超时

前言

在用TFF跑demo时候,按照官方文档的写法
https://github.com/tensorflow/federated/blob/master/docs/tutorials/federated_learning_for_image_classification.ipynb直接使用会因为连接超时而报错,所以先在本地下载好数据集放入自己创立的cache路径。
原本代码使用:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

报错内容:
在这里插入图片描述
查看相关文档
手册接口定义:

def load_data(only_digits=True, cache_dir=None)
Args:
only_digits: (Optional) whether to only include examples that are from the digits [0-9] classes. 
If `False`, includes lower and upper case characters, for a total of 62 class labels.
cache_dir: (Optional) directory to cache the downloaded file.
 If `None`, caches in Keras' default cache directory.

第一种方法

加入下载数据集的文件路径,调用:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(cache_dir = '/home/cqx/PycharmProjects/cache/fed_emnist_digitsonly')

成功:
第一种方法成功调用

第二种方法

根据报错的地方选择直接修改接口的path:
~/anaconda3/lib/python3.8/site-packages/tensorflow_federated/python/simulation/datasets/emnist.py

 load_data(only_digits, cache_dir)
  # 第105行的origin改为自己下载的数据集(fed_emnist_digitsonly.tar.bz2)路径
  # origin='https://storage.googleapis.com/tff-datasets-public/' + filename,
  filename = fileprefix + '.tar.bz2'
  path = tf.keras.utils.get_file(
      filename,
      origin='/home/cqx/PycharmProjects/cache/' + filename,
      file_hash=sha256,
      hash_algorithm='sha256',
      extract=True,
      archive_format='tar',
      cache_dir=cache_dir)

接着再次运行:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
len(emnist_train.client_ids)

输出如下,和demo的输出长度一致。
输出成功

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值