【序列化】keras custom model to_json with NotImplementedError:

文章讨论了在使用Keras进行Spark分布式预测时遇到的模型序列化问题,特别是自定义模型继承Model类导致的NotImplementedError。解决方案是改用继承Layer类,并重写get_config方法,确保所有初始化参数在get_config中。此外,文章还提到了类继承模型与层模型在序列化、拓扑结构方面的区别。
摘要由CSDN通过智能技术生成

序列化的好处:spark分布式预测,方便模型存储

这块踩坑无数,翻阅无数中英文文档,其核心问题是,你自定义的model,是继承的什么类,如果是layer类,那么就不会存在model.to_json()等模型序列化时报错,这里包括to_yarm等序列化操作。报错如下:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-14-74d572bd2453> in <module>
----> 1 model.to_json()

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in to_json(self, **kwargs)
   1207         A JSON string.
   1208     """
-> 1209     model_config = self._updated_config()
   1210     return json.dumps(
   1211         model_config, default=serialization.get_json_type, **kwargs)

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in _updated_config(self)
   1185     from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
   1186 
-> 1187     config = self.get_config()
   1188     model_config = {
   1189         'class_name': self.__class__.__name__,

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    885     if not self._is_graph_network:
    886       raise NotImplementedError
--> 887     return copy.deepcopy(get_network_config(self))
    888 
    889   @classmethod

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
   1940           filtered_inbound_nodes.append(node_data)
   1941 
-> 1942     layer_config = serialize_layer_fn(layer)
   1943     layer_config['name'] = layer.name
   1944     layer_config['inbound_nodes'] = filtered_inbound_nodes

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    138   if hasattr(instance, 'get_config'):
    139     return serialize_keras_class_and_config(instance.__class__.__name__,
--> 140                                             instance.get_config())
    141   if hasattr(instance, '__name__'):
    142     return instance.__name__

<ipython-input-11-8fbedb8f60e0> in get_config(self)
    179 
    180     def get_config(self):
--> 181         base_config = super().get_config().copy()
    182         config = {}
    183         config['name'] = self.name

/usr/local/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    884   def get_config(self):
    885     if not self._is_graph_network:
--> 886       raise NotImplementedError
    887     return copy.deepcopy(get_network_config(self))
    888 

NotImplementedError: 

那么核心问题是什么,在我们实现模型类时,如果继承的是model类,那么就会存在无法序列化问题

比如

class Intermitforecast(keras.Model):
    def __init__(self, pre_lens=28, window=180, name='IF', featurenum=36, **kwargs):
        super(Intermitforecast, self).__init__(name=name, **kwargs)

        self.pre_lens = pre_lens
        self.window = window
        self.units = 28
        self.ts_num = 1
        self.ft_num = featurenum

这里需要改变的就是,不继承Model类,继承Layer类。同时要记得在Layer类中,重写def get_config(self):。切记__init__中的所有初始化参数都要添加到get_config中

class Intermitforecast(Layer):
    def __init__(self, pre_lens=28, window=180, name='IF', featurenum=36, **kwargs):
        super(Intermitforecast, self).__init__(name=name, **kwargs)

        self.pre_lens = pre_lens
        self.window = window
        self.units = 28
        self.ts_num = 1
        self.ft_num = featurenum



………………………………


    def get_config(self):
        base_config = super().get_config().copy()
        config = {}
        config['name'] = self.name
        config.update({"pre_lens": self.pre_lens})
        config.update({"window": self.window})
        config.update({"units": self.units})
        config.update({"ts_num": self.ts_num})
        config.update({"ft_num": self.pre_lens})
        config.update({"w_reg": self.w_reg})
        config.update({"v_reg": self.v_reg})
        config.update({"hidden_units": self.hidden_units})
        return dict(list(base_config.items()) + list(config.items())) 

此处可以参考 解决NotImplementedError: Layer XX has arguments in `__init__` and therefore must override `get_config`_sinysama的博客-CSDN博客

另外:说一下我是如何发现model和layer两种类的差异的,没有遇到坑之前是万万不会理解的

from 官方文档【在类继承模型中,模型的拓扑结构是由 Python 代码定义的(而不是网络层的静态图)。这意味着该模型的拓扑结构不能被检查或序列化。因此,以下方法和属性不适用于类继承模型:】

  • model.inputsmodel.outputs
  • model.to_yaml()model.to_json()
  • model.get_config()model.save()

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值