在TensorFlow中通过类定义(继承tf.keras.Model的子类化方式)编写网络结构时,model.summary()输出的Output Shape一列有可能显示为multiple,而不是具体的形状。以下提供三种解决办法。
原文链接:python - model.summary() can't print output shape while using subclass model - Stack Overflow
解决方式1:
第一步:在__init__中添加一个input_shape参数
第二步:添加一个input_layer:
self.input_layer = tf.keras.layers.Input(input_shape)
第三步:在__init__的末尾(所有层都已定义之后)调用call方法,并传入input_layer:
self.out = self.call(self.input_layer)
示例:
class MyModel(tf.keras.Model):
    """Subclassed Keras model that traces itself on a symbolic input.

    Running ``call`` on an ``Input`` placeholder at construction time is the
    workaround this example demonstrates for ``summary()`` reporting
    ``multiple`` as the output shape.
    """

    def __init__(self, input_shape=(32, 32, 1), **kwargs):
        """Create the layers and run one symbolic pass through them.

        Args:
            input_shape: Shape handed to ``tf.keras.layers.Input``.
            **kwargs: Forwarded unchanged to ``tf.keras.Model.__init__``.
        """
        super().__init__(**kwargs)
        softmax = tf.keras.activations.softmax
        self.input_layer = tf.keras.layers.Input(input_shape)
        self.dense10 = tf.keras.layers.Dense(10, activation=softmax)
        self.dense20 = tf.keras.layers.Dense(20, activation=softmax)
        # Must run after the layers exist: call() uses dense10 and dense20.
        self.out = self.call(self.input_layer)

    def call(self, inputs):
        """Feed ``inputs`` through ``dense10`` and then ``dense20``."""
        hidden = self.dense10(inputs)
        return self.dense20(hidden)
# Instantiate the model and run one forward pass on the first 99 test
# samples, then print the model summary.
# NOTE(review): x_test is not defined in this snippet; presumably it is a
# test set loaded earlier -- confirm before running.
model = MyModel()
batch = x_test[:99]
model(batch)
# Print the shape of the batch that was actually fed to the model. The
# original printed x_test[:10].shape under an 'x_test[:99]' label.
print('x_test[:99].shape:', batch.shape)
model.summary()
解决方式2:
在调用model.summary()之前,先调用model.build()指定输入形状,再用一个Input张量调用一次model.call()。
import tensorflow as tf
from tensorflow.keras import Input, layers, Model
class subclass(Model):
    """Minimal subclassed Keras model wrapping a single ``Conv2D`` layer."""

    def __init__(self):
        """Create the convolution layer (28 filters, kernel size 3, stride 1)."""
        super().__init__()
        self.conv = layers.Conv2D(filters=28, kernel_size=3, strides=1)

    def call(self, x):
        """Apply the convolution to ``x`` and return the result."""
        features = self.conv(x)
        return features
if __name__ == '__main__':
    # Height, width, channels of one sample; the batch axis is left open.
    image_shape = (24, 24, 3)

    model = subclass()
    model.build(input_shape=(None, *image_shape))

    # Invoking call() directly on a symbolic Input is the step that makes
    # summary() print concrete output shapes instead of "multiple".
    symbolic_input = Input(shape=image_shape)
    model.call(symbolic_input)

    model.summary()
解决方式3:
在类中添加一个自定义方法:在该方法中创建一个Input张量并用它调用call()方法,然后返回由该输入和输出构建的函数式Model,最后对返回的Model调用summary()。
class subclass(Model):
    """Subclassed model that can expose itself as a functional ``Model``."""

    def __init__(self):
        # Body elided in the original write-up.
        ...

    def call(self, x):
        # Body elided in the original write-up.
        ...

    def model(self):
        """Return a functional ``Model`` built by tracing ``call``.

        A symbolic ``Input`` of shape (24, 24, 3) is passed through
        ``self.call`` and the resulting input/output pair is wrapped in a
        ``Model``.
        """
        symbolic_input = Input(shape=(24, 24, 3))
        symbolic_output = self.call(symbolic_input)
        return Model(inputs=[symbolic_input], outputs=symbolic_output)
if __name__ == '__main__':
    sub = subclass()
    # Summarise the functional wrapper rather than the subclassed model.
    functional_view = sub.model()
    functional_view.summary()