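Problem description
Today brings yet another head-splitting debugging session. Please welcome the star of the show, today's bug! (For the core error, skip straight to the last line.)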
Process Process-1:1:
Traceback (most recent call last):
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
self.run()
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/home/xuzheyang/experiment/colight/pipeline.py", line 189, in generator_wrapper
best_round=best_round
File "/home/xuzheyang/experiment/colight/generator.py", line 71, in __init__
intersection_id=str(i)
File "/home/xuzheyang/experiment/colight/GIN.py", line 127, in __init__
self.q_network_bar = self.build_network_from_copy(self.q_network)
File "/home/xuzheyang/experiment/colight/GIN.py", line 455, in build_network_from_copy
network = model_from_json(network_structure, custom_objects={"RepeatVector3D": RepeatVector3D, "MLP": MLP})
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/models.py", line 379, in model_from_json
return layer_module.deserialize(config, custom_objects=custom_objects)
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
printable_module_name='layer')
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
list(custom_objects.items())))
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/engine/topology.py", line 2525, in from_config
process_layer(layer_data)
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/engine/topology.py", line 2511, in process_layer
custom_objects=custom_objects)
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
printable_module_name='layer')
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 146, in deserialize_keras_object
return cls.from_config(config['config'])
File "/home/xuzheyang/anaconda3/envs/pytorch_gpu/lib/python3.6/site-packages/keras/engine/topology.py", line 1271, in from_config
return cls(**config)
TypeError: __init__() got an unexpected keyword argument 'name'
As is well known, when Keras loads a model that contains custom layers or functions, you have to pass them in through custom_objects, like this:
def build_network_from_copy(self, network_copy):
    ''' Initialize a Q network from a copy '''
    network_structure = network_copy.to_json()
    network_weights = network_copy.get_weights()
    print(network_structure)
    network = model_from_json(network_structure, custom_objects={"RepeatVector3D": RepeatVector3D, "MLP": MLP})
    network.set_weights(network_weights)
    network.compile(
        optimizer=RMSprop(lr=self.dic_agent_conf["LEARNING_RATE"]),
        loss=self.dic_agent_conf["LOSS_FUNCTION"])
    return network
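I understand the principle, but until now I had only used custom functions and custom loss functions, which never involve class initialization, so I had not run into this kind of error. Today I wrote a standalone MLP layer as its own class, and it blew up on me right away.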
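Why does the error mention 'name' at all? The last two frames of the traceback show from_config calling cls(**config), and the config that a base layer serializes includes the layer's name. The sketch below reproduces that mechanism in plain Python with no Keras involved. BaseLayer, BrokenMLP and try_round_trip are hypothetical stand-ins written for illustration; the real Keras classes do far more than this.

```python
class BaseLayer:
    """Hypothetical stand-in for a Keras-style base layer (not the real class)."""

    def __init__(self, name=None):
        self.name = name or self.__class__.__name__.lower()

    def get_config(self):
        # The base class only knows about its own constructor arguments.
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Same pattern as the last traceback frame: return cls(**config)
        return cls(**config)


class BrokenMLP(BaseLayer):
    # No **kwargs in __init__, and get_config is not overridden.
    def __init__(self, num_layers, hidden_dim, output_dim):
        super().__init__()
        self.num_layers = num_layers


def try_round_trip(layer):
    """Rebuild a layer from its own config; return the TypeError text, or None."""
    try:
        type(layer).from_config(layer.get_config())
        return None
    except TypeError as exc:
        return str(exc)


error = try_round_trip(BrokenMLP(2, 64, 8))
print(error)
```

The printed message should end with "got an unexpected keyword argument 'name'", the same complaint as the last line of the traceback, although the exact prefix may differ between Python versions.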
Solution
Step 1:
Add a "**kwargs" parameter to the custom layer's __init__ method:
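from keras.layers import Layer

# Linear_model and Multi_model are helper classes defined elsewhere in the project.
class MLP(Layer):
    def __init__(self, num_layers, hidden_dim, output_dim, **kwargs):
        # Forward base-layer arguments such as name to the parent class
        super(MLP, self).__init__(**kwargs)
        if num_layers < 1:
            raise ValueError("number of layers should be positive")
        # Keep the constructor arguments so that get_config can serialize them
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.linear_or_not = True
        if num_layers == 1:
            self.linear = Linear_model()
        else:
            self.linear_or_not = False
            self.multi = Multi_model(layers=num_layers, hidden_dim=hidden_dim, output_dim=output_dim)

    def call(self, input_feature, **kwargs):
        if self.linear_or_not:
            return self.linear(input_feature)
        else:
            return self.multi(input_feature)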
Step 2:
Override the get_config method in the custom layer:
def get_config(self):
    config = {'num_layers': self.num_layers, 'hidden_dim': self.hidden_dim, 'output_dim': self.output_dim}
    base_config = super(MLP, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
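The two steps only work together: "**kwargs" lets __init__ accept the base-layer entries in the config, and get_config makes sure the layer's own constructor arguments are in that config to begin with. Here is a minimal plain-Python sketch of the full round trip, assuming a simplified base class. BaseLayer and FixedMLP are hypothetical names used for illustration, not Keras classes, and the layer-building logic is left out.

```python
class BaseLayer:
    """Hypothetical stand-in for a Keras-style base layer (not the real class)."""

    def __init__(self, name=None):
        self.name = name or self.__class__.__name__.lower()

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer by passing every config entry to __init__.
        return cls(**config)


class FixedMLP(BaseLayer):
    def __init__(self, num_layers, hidden_dim, output_dim, **kwargs):
        # Step 1: accept base-layer arguments such as name and pass them on.
        super().__init__(**kwargs)
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

    def get_config(self):
        # Step 2: add this layer's constructor arguments to the base config.
        config = super().get_config()
        config.update({
            "num_layers": self.num_layers,
            "hidden_dim": self.hidden_dim,
            "output_dim": self.output_dim,
        })
        return config


original = FixedMLP(num_layers=2, hidden_dim=64, output_dim=8, name="mlp_1")
rebuilt = FixedMLP.from_config(original.get_config())
print(rebuilt.get_config())
```

In this sketch, dropping the get_config override while keeping "**kwargs" would trade one TypeError for another: from_config would call the constructor with only the name, and the three required arguments would be missing.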
All done! (Once again, hats off to Google.)