Preface
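When running the TFF demo as written in the official tutorial
https://github.com/tensorflow/federated/blob/master/docs/tutorials/federated_learning_for_image_classification.ipynb
the dataset loader fails with a connection timeout when called directly. To work around this, I first downloaded the dataset locally and placed it in a cache directory I created.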
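The original code calls the loader as follows (its signature and documented arguments are quoted underneath):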
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
def load_data(only_digits=True, cache_dir=None)
Args:
only_digits: (Optional) whether to only include examples that are from the digits [0-9] classes.
If `False`, includes lower and upper case characters, for a total of 62 class labels.
cache_dir: (Optional) directory to cache the downloaded file.
If `None`, caches in Keras' default cache directory.
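The manual download step mentioned above can be sketched with the standard library alone. This is a minimal sketch, not part of TFF: the base URL and archive name are taken from the emnist.py excerpt quoted later in this post, while `archive_url`, `download_archive`, and the 600-second timeout are hypothetical names and values chosen for illustration.

```python
import shutil
import urllib.request
from pathlib import Path

# Base URL and file name as they appear in the emnist.py excerpt in this post.
BASE_URL = 'https://storage.googleapis.com/tff-datasets-public/'
FILENAME = 'fed_emnist_digitsonly.tar.bz2'


def archive_url(filename=FILENAME):
    """Returns the remote URL of the archive (hypothetical helper)."""
    return BASE_URL + filename


def download_archive(dest_dir, filename=FILENAME, timeout=600):
    """Downloads the archive into dest_dir using a generous timeout.

    The timeout value is an assumption; raise it on slow connections.
    """
    dest = Path(dest_dir) / filename
    dest.parent.mkdir(parents=True, exist_ok=True)
    with urllib.request.urlopen(archive_url(filename), timeout=timeout) as resp:
        with open(dest, 'wb') as out:
            shutil.copyfileobj(resp, out)  # stream to disk in chunks
    return dest
```

If the connection is blocked entirely rather than slow, the same URL can be fetched on another machine and the file copied over by hand.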
Method 1
Pass the directory that holds the downloaded dataset as cache_dir:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(cache_dir = '/home/cqx/PycharmProjects/cache/fed_emnist_digitsonly')
The call then succeeds.
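If Method 1 still tries to download, it may help to check where the loader expects the archive. The sketch below assumes that this TFF version calls `tf.keras.utils.get_file` without a `cache_subdir` argument (as in the excerpt in Method 2), in which case Keras looks under `<cache_dir>/datasets`. The helper names `expected_archive_path` and `is_cached` are hypothetical.

```python
import os

FILENAME = 'fed_emnist_digitsonly.tar.bz2'


def expected_archive_path(cache_dir, filename=FILENAME):
    """Path where Keras' get_file is assumed to look for the archive.

    Assumption: get_file uses its default cache_subdir='datasets'.
    """
    return os.path.join(os.path.expanduser(cache_dir), 'datasets', filename)


def is_cached(cache_dir, filename=FILENAME):
    """True if the archive is already present at the expected location."""
    return os.path.isfile(expected_archive_path(cache_dir, filename))
```

For example, `is_cached('/home/cqx/PycharmProjects/cache/fed_emnist_digitsonly')` reports whether the archive sits where the loader would look for it, before any TFF call is made.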
Method 2
Following the location reported in the error traceback, edit the download path directly in the library source:
~/anaconda3/lib/python3.8/site-packages/tensorflow_federated/python/simulation/datasets/emnist.py
load_data(only_digits, cache_dir)
# On line 105, change origin to the path of the locally downloaded dataset (fed_emnist_digitsonly.tar.bz2)
# origin='https://storage.googleapis.com/tff-datasets-public/' + filename,
filename = fileprefix + '.tar.bz2'
path = tf.keras.utils.get_file(
filename,
origin='file:///home/cqx/PycharmProjects/cache/' + filename,  # urllib expects a URL scheme, so a local file is given as file://
file_hash=sha256,
hash_algorithm='sha256',
extract=True,
archive_format='tar',
cache_dir=cache_dir)
Then run the original code again:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
len(emnist_train.client_ids)
The printed number of client IDs matches the length shown in the demo.
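Two details of the `get_file` call in Method 2 can be checked before editing the library: `origin` is handed to urllib, which generally expects a URL with a scheme, and `file_hash` is compared against the archive's SHA-256. The sketch below uses only the standard library; `local_origin` and `sha256_of` are hypothetical helper names, not part of TFF or Keras.

```python
import hashlib
from pathlib import Path


def local_origin(path):
    """Converts a local path into a file:// URL usable as an origin."""
    return Path(path).resolve().as_uri()


def sha256_of(path, chunk_size=1 << 20):
    """Computes the SHA-256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
```

Comparing `sha256_of(...)` for the downloaded archive with the `sha256` value in emnist.py shows whether the manual download is intact; a mismatch would make `get_file` treat the cached file as stale.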