Error message
Traceback (most recent call last):
File "D:/***.py", line 112, in <module>
train()
File "D:/***.py", line 102, in train
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
File "E:\Anaconda3\lib\site-packages\tensorflow_federated\python\learning\federated_averaging.py", line 229, in build_federated_averaging_process
model_update_aggregation_factory=model_update_aggregation_factory)
File "E:\Anaconda3\lib\site-packages\tensorflow_federated\python\learning\framework\optimizer_utils.py", line 610, in build_model_delta_optimizer_process
model_weights_type = model_utils.weights_type_from_model(model_fn)
File "E:\Anaconda3\lib\site-packages\tensorflow_federated\python\learning\model_utils.py", line 100, in weights_type_from_model
model = model()
File "D:/Program Files/PyCharm 2019.2/GraduationDesign/tff_test.py", line 94, in model_fn
loss=tf.keras.losses.mean_squared_error(),
File "E:\Anaconda3\lib\site-packages\tensorflow\python\util\dispatch.py", line 201, in wrapper
return target(*args, **kwargs)
TypeError: mean_squared_error() missing 2 required positional arguments: 'y_true' and 'y_pred'
Solution
The code that raised the error:
return tff.learning.from_keras_model(
model,
input_spec=train_data[0].element_spec,
loss=tf.keras.losses.mean_squared_error(),
metrics=[tf.keras.metrics.mean_absolute_percentage_error()])
The loss function is defined on this line:
loss=tf.keras.losses.mean_squared_error()
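This line is what triggers the "missing arguments" error. tf.keras.losses.mean_squared_error is a plain function that takes y_true and y_pred, so the trailing () calls it immediately with no arguments and raises the TypeError above. The metric, tf.keras.metrics.mean_absolute_percentage_error(), has the same problem. tff.learning.from_keras_model expects a tf.keras.losses.Loss object for loss and tf.keras.metrics.Metric objects for metrics, so the class versions should be used instead.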
Change the loss and metric to:
loss=tf.keras.losses.MeanSquaredError(),
metrics=[tf.keras.metrics.MeanAbsolutePercentageError()])
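To see why the parentheses matter, here is a minimal pure-Python sketch that reproduces the same TypeError without TensorFlow. The function mean_squared_error and the class MeanSquaredError below are hypothetical stand-ins that only imitate the two styles Keras offers (a plain loss function versus a loss class); they are not the real tf.keras implementations.

```python
# Hypothetical pure-Python stand-ins (NOT TensorFlow code) that mimic the
# function form and the class form of a Keras loss.

def mean_squared_error(y_true, y_pred):
    """Function form: needs both inputs every time it is called."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


class MeanSquaredError:
    """Class form: constructed with no arguments, called later with data."""

    def __call__(self, y_true, y_pred):
        return mean_squared_error(y_true, y_pred)


# Wrong: the trailing () calls the function right away, with no data.
error_message = ""
try:
    loss = mean_squared_error()
except TypeError as err:
    error_message = str(err)

# Right: create an object now; the training code calls it later with data.
loss = MeanSquaredError()
value = loss([0.0, 1.0], [1.0, 1.0])

print(error_message)
print(value)  # 0.5
```

The class instance is an object that can be stored and called later, which is what a training framework needs; calling the bare function with () instead evaluates it on the spot, before any y_true or y_pred exist.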