every blog every motto: A man can be destroyed but not defeated.
0. 前言
下载官方预训练模型参数,以VGG16为例
说明: 最简便的办法是用下面给出的网址,将模型参数下载后,放在.keras/model 文件夹下。如果用程序下的化比较慢
1. 正文
1.1 方法一:get_file
参数说明:
filename: 下载后保存的文件名
url: 下载地址
cache_subdir: 模型保存的文件夹
get_file返回下载后的绝对路径
# 下载模型参数
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
filename = 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5' # 下载后保存的文件名
weights_path = get_file(filename,WEIGHTS_PATH_NO_TOP,cache_subdir='models_dir')
下载的参数,会保存在如下位置:
特别说明: 函数会先检查下面文件是否有.h5文件(filename),如果不存在,才会下载。如果下载失败,也可能会不有完整的文件!!!,导致再次下载失败
**解决办法:**加一个参数,就行。
md5_hash: 验证文件是否有损坏。我试了一下,只要有参数就行,比如:checksum=‘3’。
# 下载模型参数
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
filename = 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5' # 下载后保存的文件名
checksum = '3e9f4e4f77bbe2c9bec13b53ee1c2319'
weights_path = get_file(filename,WEIGHTS_PATH_NO_TOP,md5_hash=checksum,cache_subdir='models_dir')
1.1.1 有关地址
说明: 链接包括两种,分别为:include_top=Ture/False。
VGG16:
WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5')#500MB
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5')#50MB
VGG19:
WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')
ResNet50:
WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
1.2 方法二:keras.applications
说明: 此方法获取模型 和 参数
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input
import numpy as np
model = VGG16(weights='imagenet',include_top=False)
img_path = 'elephant.jpg'
img = image.load_img(img_path,target_size=(224,224))
x = image.img_to_array(img)
x = np.expand_dims(x,axis=0)
x = preprocess_input(x)
features = model.predict(x)
1.3 补充说明
1.3.1 notop 模型是什么?
表示是否包含最后3个全连接层,用来做微调(fine-tuning) ,专门开源了这类模型。
1.3.2 include_top
关于参数 include_top,如下:
fc_model = VGG16(include_top=True)
notop_model = VGG16(include_top=False)
上例notop_model 模型,为没有全连接层的模型;fc_model为有全连接层的模型
参考文章
[1] https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/utils/get_file
[2] https://blog.csdn.net/weixin_38145317/article/details/97370541
[3] https://blog.csdn.net/a1920993165/article/details/105060972
[4] https://blog.csdn.net/sinat_26917383/article/details/72859145
[5] https://keras.io/zh/applications/