custom_objects参数的使用,以及TypeError: __init__() got an unexpected keyword argument name

问题描述

今天依旧是是令人裂开的debug环节,下面有请本次bug闪亮登场!(核心报错请直接看最后一行)

Process Process-1:1:
Traceback (most recent call last):
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
self.run()
 File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/home/xuzheyang/experiment/colight/pipeline.py", line 189, in generator_wrapper
best_round=best_round
File "/home/xuzheyang/experiment/colight/generator.py", line 71, in __init__
intersection_id=str(i)
File "/home/xuzheyang/experiment/colight/GIN.py", line 127, in __init__
self.q_network_bar = self.build_network_from_copy(self.q_network)
File "/home/xuzheyang/experiment/colight/GIN.py", line 455, in build_network_from_copy
network = model_from_json(network_structure, custom_objects={"RepeatVector3D": RepeatVector3D, "MLP": MLP})
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/models.py", line 379, in model_from_json
return layer_module.deserialize(config, custom_objects=custom_objects)
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
printable_module_name='layer')
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
list(custom_objects.items())))
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/engine/topology.py", line 2525, in from_config
process_layer(layer_data)
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/engine/topology.py", line 2511, in process_layer
custom_objects=custom_objects)
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
printable_module_name='layer')
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 146, in deserialize_keras_object
return cls.from_config(config['config'])
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/engine/topology.py", line 1271, in from_config
return cls(**config)
TypeError: __init__() got an unexpected keyword argument 'name'

众所周知,keras框架加载含有自定义层或函数的模型的时候需要启用custom_objects,使用方法如下:

    def build_network_from_copy(self, network_copy):

        ''' Initialize a Q network from a copy '''
        network_structure = network_copy.to_json()
        network_weights = network_copy.get_weights()
        print(network_structure)
        network = model_from_json(network_structure, custom_objects={"RepeatVector3D": RepeatVector3D, "MLP": MLP})
        network.set_weights(network_weights)

        network.compile(
            optimizer=RMSprop(lr=self.dic_agent_conf["LEARNING_RATE"]),
            loss=self.dic_agent_conf["LOSS_FUNCTION"])

        return network

道理我都懂,但是以前只用过自定义函数和自定义损失函数,不会设计的class的初始化的问题,因此没碰到过此类报错。今天写了一个外部的类MLP层,于是我直接裂开。

解决方案

step1:

在自定义layer的__init__方法里加入“**kwargs”这个参数

class MLP(Layer):
    def __init__(self, num_layers, hidden_dim, output_dim, **kwargs):
        super(MLP, self).__init__()

        self.linear_or_not =True
        self.num_layer = num_layers

        if num_layers < 1:
            raise ValueError("number of layers should be positive")
        elif num_layers == 1:
            self.linear = Linear_model()
        else:
            self.linear_or_not = False
            self.multi = Multi_model(layers=num_layers, hidden_dim=hidden_dim, output_dim=output_dim)

    def call(self, input_feature, **kwargs):
        if self.linear_or_not:
            return self.linear(input_feature)
        else:
            return self.multi(input_feature)
step2:

给自定义layer重写get_config方法:

    def get_config(self):
        config = {'num_layers': self.num_layers, 'hidden_dim': self.hidden_dim, 'output_dim': self.output_dim}
        base_config = super(MLP, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

大功告成!(再次感慨google牛逼)

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值