项目场景:
复现SRGAN 时 使用VGG19 9层 提取特征时报错。由于tensorflow版本更新所造成得。
问题描述:
代码如下
def build_vgg(self):
# 建立VGG模型,只使用第9层的特征
vgg = VGG19(weights="imagenet")
vgg.outputs = [vgg.layers[9].output]
img = Input(shape=self.hr_shape)
img_features = vgg(img)
return Model(img, img_features)
更改为:
def build_vgg(self):
# 建立VGG模型,只使用第9层的特征
vgg = VGG19(weights="imagenet",input_shape=self.hr_shape,include_top=False)
return Model(vgg.input, outputs=vgg.layers[9].output)