问题描述
在转写GIN (GRAPH ISOMORPHISM NETWORK)的pytorch代码为keras代码时,随手使用了K.relu( )函数,导致题目的报错,报错代码如下:
def next_layer(self, h, layer, adj_matrix=None):
h2 = h
pooled = Lambda(lambda x: K.batch_dot(x[0], x[1]))([adj_matrix, h2])
pooled_rep = self.mlps[layer](pooled)
h = self.batches[layer](pooled_rep)
h = K.relu(h)
return h
debug的过程中查看变量的值,发现变量h经过 K.relu( )之前存在_keras_history属性,经过K.relu( )后就没有了。
解决方法
- 使用keras.backend内的函数
- 使用keras.layers.Lambda
修改后代码:
def next_layer(self, h, layer, adj_matrix=None):
h2 = h
pooled = Lambda(lambda x: K.batch_dot(x[0], x[1]))([adj_matrix, h2])
pooled_rep = self.mlps[layer](pooled)
h = self.batches[layer](pooled_rep)
h = Lambda(lambda x: K.relu(x))(h)
return h
再次debug查看,成功。