every blog every motto: You can do more than you think.
0. 前言
用子类API(Subclassing API)构建模型时出现的问题,小结。
说明: 有关子类API的部分问题探讨,参见上节
1. 正文
对于构建完以后的模型,我们往往会打印其网络结构,一方面:验证模型是否通顺,另一方面,检查模型特征图变化。,当我们使用model.summary()时出现如下错误。
ValueError: This model has not yet been built. Build the model first by calling `build()` or calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s) for automatic build.
1.1 完整代码
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU
class Models(tf.keras.Model):
def __init__(self):
super().__init__()
self.conv = Conv2D(16, (3, 3), padding='same')
self.bn = BatchNormalization()
self.ac = ReLU()
self.conv2 = Conv2D(32, (3, 3), padding='same')
self.bn2 = BatchNormalization()
self.ac2 = ReLU()
def call(self, x, **kwargs):
x = self.conv(x)
x = self.bn(x)
x = self.ac(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.ac2(x)
return x
m = Models()
m.summary()
1.2 原因及解决方法
1.2.1 原因
- 在使用Sequential 或Functional API,已经“自动”(implicitly)帮你调用了model.build方法;而通过上面的代码(子类API)构建模型,需要自己调用model.build方法(explicitly)
- Sequential / Functional API是一种有向无环图的数据结构,当在第一层指定输入数据的shape时,模型可以自动推断所有层的shape并build模型,同样的,你也可以用过model.summary()打印各层的shape;而子类API是python的call方法定义的,所以没有各层的图结构,也就不能自动推断各层是如何连接的,更不能推断各层的shape。
当使用子类API时,我们需要对模型进行build。
1.2.2 解决方法
m = Models()
m.build(input_shape=(2,8,8,3))
m.summary()