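While writing a Keras model today I ran into a small bug that came down entirely to my shaky fundamentals. I needed to apply some custom transformations to the data as it flowed into the next layer, so I used a Lambda layer, and that is when I got "RecursionError: maximum recursion depth exceeded". Reading the error output, my guess was that saving the model ran into runaway recursion: the traceback shows get_config() deep-copying the model config until Python's recursion limit is hit (a recursion-depth error, not an out-of-memory error). Here is the offending code.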
pool2 = MaxPooling2D(pool_size=[1, 4], padding='valid')(conv4_1)
conv5_1 = Lambda(lambda x: self.Seq_Conv(x), name='conv5_1')(pool2)
attention_probs = Dense(22, activation='softmax', name='attention_probs')(conv5_1)
attention_mul = Multiply()([conv5_1, attention_probs])
def Seq_Conv(self, inputs):
    # len = inputs.shape[3] / output
    len = 64
    output = 22
    X = None
    for i in range(output):
        if i == 0:
            con = Conv2D(filters=1, kernel_size=[1, 1],
                         padding='valid', activation=None)(inputs[:, :, :, :(i+1)*len])
        elif i < output-1:
            con = Conv2D(filters=1, kernel_size=[1, 1],
                         padding='valid', activation=None)(inputs[:, :, :, i*len:(i+1)*len])
        if X is None:
            X = con
        else:
            X = Concatenate()([X, con])
    return X
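To make the slicing in Seq_Conv easier to follow, here is a small pure-Python sketch that reproduces only the index arithmetic of the loop, with no Keras involved. The helper name group_slices is hypothetical and not part of the original code. It suggests two things about the loop as written: the i == 0 branch takes the same slice the general formula would, and the last iteration (i == output-1) matches neither branch, so the previous `con` is concatenated a second time.

```python
def group_slices(group_len, groups):
    """Mirror the if/elif in Seq_Conv and return the (start, stop) channel slices it takes."""
    slices = []
    for i in range(groups):
        if i == 0:
            # Same as the general case: 0 * group_len .. 1 * group_len
            slices.append((0, (i + 1) * group_len))
        elif i < groups - 1:
            slices.append((i * group_len, (i + 1) * group_len))
        # i == groups - 1 matches neither branch, so no new slice is taken
    return slices


slices = group_slices(group_len=64, groups=22)
print(len(slices))   # 21 distinct slices for 22 loop iterations
print(slices[0])     # (0, 64)
print(slices[-1])    # (1280, 1344)
```

If 22 distinct groups are intended, the condition would probably need to cover the last index as well; that depends on the real channel count of `pool2`, which the post does not state.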
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/site-packages/keras/engine/network.py", line 1263, in __getstate__
    return saving.pickle_model(self)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/site-packages/keras/engine/saving.py", line 429, in pickle_model
    _serialize_model(model, f)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/site-packages/keras/engine/saving.py", line 83, in _serialize_model
    model_config['config'] = model.get_config()
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/site-packages/keras/engine/network.py", line 931, in get_config
    return copy.deepcopy(config)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 243, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 218, in _deepcopy_list
    y.append(deepcopy(a, memo))
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 243, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 243, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 223, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 223, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 223, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 223, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 182, in deepcopy
    y = _reconstruct(x, rv, 1, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 297, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 155, in deepcopy
    y = copier(x, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 243, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/joerg/applications/envs/tud_lm/lib/python3.5/copy.py", line 151, in deepcopy
    cls = type(x)
RecursionError: maximum recursion depth exceeded
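The traceback ends inside copy.deepcopy, which recurses once per level of nesting in the object it copies. The following standard-library sketch illustrates the general mechanism rather than the Keras case itself: deepcopy copes with self-referential objects through its memo dictionary, but raises RecursionError when the object graph is nested more deeply than the interpreter's recursion limit. Node, build_chain, and deepcopy_hits_limit are hypothetical names used only for this illustration.

```python
import copy
import sys


class Node:
    """A tiny object holding a reference to one child."""

    def __init__(self, child=None):
        self.child = child


def build_chain(depth):
    """Build `depth` nested Node objects iteratively (no recursion needed)."""
    node = None
    for _ in range(depth):
        node = Node(node)
    return node


def deepcopy_hits_limit(obj):
    """Return True if copy.deepcopy(obj) raises RecursionError."""
    try:
        copy.deepcopy(obj)
    except RecursionError:
        return True
    return False


# A cycle is fine: deepcopy's memo remembers objects it has already copied.
cyclic = Node()
cyclic.child = cyclic
cyclic_copy = copy.deepcopy(cyclic)

# A very deep chain is not: every level costs several Python stack frames.
deep_chain = build_chain(sys.getrecursionlimit() * 2)
too_deep = deepcopy_hits_limit(deep_chain)
shallow_ok = not deepcopy_hits_limit(build_chain(10))
```

So a RecursionError from deepcopy does not necessarily mean an infinite loop; it can simply mean that something very large and deeply nested ended up inside the structure being copied.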
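After a lot of searching on Baidu, I finally found the solution in a blog post and on GitHub. The links are:
Blog post
GitHub
Based on those two sources, I changed the code as follows: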
conv5_1 = Lambda(self.Seq_Conv, name='conv5_1')(pool2)
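Writing it the standard way, passing the method itself to Lambda instead of wrapping it in a lambda expression, is all it takes to make the problem go away.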
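One possible explanation for why this works, which the post does not confirm: a lambda that calls self.Seq_Conv keeps a reference to `self` in its closure, whereas a bound method is not a plain function object and exposes a simple `__name__`. If the Keras version in the traceback stores a lambda's closure in the layer config but only the name of other callables (an assumption about Keras internals), then deep-copying the config would walk into `self` and everything it references. The sketch below checks only the plain-Python part of that reasoning; ModelBuilder and its methods are hypothetical names.

```python
import types


class ModelBuilder:
    """Stand-in for the class that owns Seq_Conv; no Keras involved."""

    def seq_conv(self, x):
        return x

    def wrapped_in_lambda(self):
        # Mirrors: Lambda(lambda x: self.Seq_Conv(x), ...)
        return lambda x: self.seq_conv(x)

    def passed_directly(self):
        # Mirrors: Lambda(self.Seq_Conv, ...)
        return self.seq_conv


builder = ModelBuilder()
wrapped = builder.wrapped_in_lambda()
direct = builder.passed_directly()

# The lambda's closure cells hold the objects it captured, including `self`.
captured = [cell.cell_contents for cell in (wrapped.__closure__ or ())]
lambda_captures_self = any(obj is builder for obj in captured)

# types.LambdaType is the plain function type; a bound method is a MethodType.
lambda_is_plain_function = isinstance(wrapped, types.LambdaType)
method_is_plain_function = isinstance(direct, types.LambdaType)
method_name = direct.__name__

print(lambda_captures_self, lambda_is_plain_function)
print(method_is_plain_function, method_name)
```

Both callables compute the same result for the same input; the difference is only in what a serializer sees when it inspects them.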