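BUG
When training a model with Keras, saving the model at the end of each epoch (save_model) fails with "AttributeError: 'NoneType' object has no attribute 'update'".
The traceback is shown below. The root cause is that the model contains a custom class: Keras's get_config() calls copy.deepcopy(config), and deepcopy does not know how to copy that class.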
  File "train.py", line 88, in <module>
    build_model()
  File "train.py", line 80, in build_model
    CSVLogger(log_path),
  File "/usr/python/lib/python3.5/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/usr/python/lib/python3.5/site-packages/keras/engine/training.py", line 2268, in fit_generator
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/usr/python/lib/python3.5/site-packages/keras/callbacks.py", line 77, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/usr/python/lib/python3.5/site-packages/keras/callbacks.py", line 447, in on_epoch_end
    self.model.save(filepath, overwrite=True)
  File "/usr/python/lib/python3.5/site-packages/keras/engine/topology.py", line 2591, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/usr/python/lib/python3.5/site-packages/keras/models.py", line 126, in save_model
    'config': model.get_config()
  File "/usr/python/lib/python3.5/site-packages/keras/engine/topology.py", line 2432, in get_config
    return copy.deepcopy(config)
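Solutions
Method 1
Pass save_weights_only=True to the ModelCheckpoint callback. To load the model later, rebuild the architecture in code and then load the saved weights into it. Because only the weights are written, get_config() and copy.deepcopy() are never called.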
callbacks=[ModelCheckpoint(model_path, save_weights_only=True,
                           monitor='val_loss', mode='min', save_best_only=True)]
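The save-then-reload cycle for Method 1 could look roughly like the sketch below. The tiny network in build_net(), the random training data, and the temporary file path are all assumptions made up for illustration. Only the ModelCheckpoint arguments come from the snippet above. The ".weights.h5" suffix is chosen because recent Keras releases require it for weights-only checkpoints. Older releases accept it as an ordinary HDF5 file name.

```python
import os
import tempfile

import numpy as np
from keras.callbacks import ModelCheckpoint
from keras.layers import Dense, Input
from keras.models import Model


def build_net():
    # Hypothetical architecture; the same function is used for training
    # and for loading, so the weights always match the layers.
    inputs = Input(shape=(4,))
    outputs = Dense(1)(inputs)
    return Model(inputs=inputs, outputs=outputs)


# Toy data, only here to make the sketch runnable.
x = np.random.rand(64, 4).astype("float32")
y = x.sum(axis=1, keepdims=True)

# Assumed location for the checkpoint file.
model_path = os.path.join(tempfile.mkdtemp(), "best.weights.h5")

model = build_net()
model.compile(optimizer="sgd", loss="mse")
model.fit(x, y, validation_split=0.25, epochs=2, verbose=0,
          callbacks=[ModelCheckpoint(model_path, save_weights_only=True,
                                     monitor='val_loss', mode='min',
                                     save_best_only=True)])

# Loading: rebuild the architecture in code, then restore the weights.
restored = build_net()
restored.load_weights(model_path)
predictions = restored.predict(x, verbose=0)
```

The trade-off is that the checkpoint no longer contains the architecture or optimizer state, so the model-building code has to be kept alongside the weights file.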
Method 2
For a class that cannot be deep-copied, define a custom copy method by implementing __deepcopy__.
class YourClass(object):
    # ...
    def __deepcopy__(self, memo):
        # copy.deepcopy() passes its memo dict as the second argument
        return Your_deep_copy()
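As a concrete illustration of Method 2, here is a minimal sketch that uses only the standard library. The Counter class and its _lock attribute are hypothetical stand-ins for whatever custom object sits inside the model config. A threading.Lock is used because copy.deepcopy() cannot copy one by default.

```python
import copy
import threading


class Counter(object):
    """Hypothetical class holding a lock, which deepcopy cannot copy by default."""

    def __init__(self, value=0):
        self.value = value
        self._lock = threading.Lock()

    def __deepcopy__(self, memo):
        # Create the copy without calling __init__ and register it in memo
        # first, so reference cycles resolve to the same new object.
        cls = self.__class__
        new = cls.__new__(cls)
        memo[id(self)] = new
        for name, attr in self.__dict__.items():
            if name == "_lock":
                # Give the copy a fresh lock instead of copying the old one.
                new._lock = threading.Lock()
            else:
                setattr(new, name, copy.deepcopy(attr, memo))
        return new


config = {"counter": Counter(3)}
cloned = copy.deepcopy(config)  # works because Counter defines __deepcopy__
```

Without the __deepcopy__ method, the copy.deepcopy(config) call would raise a TypeError on the lock. With it, the copy gets its own independent lock and a deep copy of every other attribute.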