前言
一些预训练好的模型,可以被用来当做特征提取器,如何使用成为关键。
场景: 基于TF2.0提供的一些预训练好的2D图像分类网络,进行迁移学习。
抽取网络
basemodel = tf.keras.applications.ResNet50(weights='imagenet',
input_shape=(224, 224, 3),
include_top=Fasle)
backbone = tf.keras.Model(inputs=basemodel.input,
outputs=basemodel.get_layer('层名称').output)
其中“层名称”,可以通过basemodel.summary()输出网络结构,选取自己想要的层,然后记住名称,放入即可,然后这个backbone就可以直接使用了。