# Required import: from keras import models [as alias]
# Or: from keras.models import save_model [as alias]
def _patch_io_calls(Network, Sequential, keras_saving):
    """Install the PatchKerasModelIO hooks on the Keras model I/O entry points.

    Every target attribute is replaced by
    ``_patched_call(<original attribute>, <PatchKerasModelIO hook>)``.
    Any argument may be ``None``, in which case that target is skipped.

    :param Network: class whose ``_updated_config``, ``from_config``, ``save``,
        ``save_weights`` and ``load_weights`` attributes are patched, or None.
    :param Sequential: class whose ``_updated_config`` and ``from_config``
        attributes are patched, or None.
    :param keras_saving: object exposing ``save_model`` and ``load_model``
        attributes to patch, or None.
    :return: None. Patching is best-effort: any exception is logged as a
        warning and not propagated. Patches applied before the failure are
        left in place; the remaining ones are skipped.
    """

    def _patch_from_config(model_cls):
        # Wrap ``model_cls.from_config`` while preserving how it is bound.
        original = model_cls.from_config
        if hasattr(original, '__func__'):
            # Bound (class)method: wrap the underlying function and re-wrap
            # it as a classmethod so it still receives the class.
            # noinspection PyUnresolvedReferences
            model_cls.from_config = classmethod(
                _patched_call(original.__func__, PatchKerasModelIO._from_config))
        else:
            model_cls.from_config = _patched_call(original, PatchKerasModelIO._from_config)

    try:
        if Sequential is not None:
            Sequential._updated_config = _patched_call(
                Sequential._updated_config, PatchKerasModelIO._updated_config)
            _patch_from_config(Sequential)

        if Network is not None:
            Network._updated_config = _patched_call(
                Network._updated_config, PatchKerasModelIO._updated_config)
            # Inspect Network's own from_config. The previous code checked
            # Sequential.from_config here, which raised when Sequential was
            # None and aborted every remaining patch.
            _patch_from_config(Network)
            Network.save = _patched_call(Network.save, PatchKerasModelIO._save)
            Network.save_weights = _patched_call(Network.save_weights, PatchKerasModelIO._save_weights)
            Network.load_weights = _patched_call(Network.load_weights, PatchKerasModelIO._load_weights)

        if keras_saving is not None:
            keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
            keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
    except Exception as ex:
        # Deliberate best-effort boundary: failing to patch must not break
        # the caller, so report the problem and carry on.
        LoggerRoot.get_base_logger(TensorflowBinding).warning(str(ex))