【Tensorflow】下载预训练模型和参数小结

every blog every motto: A man can be destroyed but not defeated.

0. 前言

下载官方预训练模型参数,以VGG16为例
说明: 最简便的办法是用下面给出的网址,将模型参数下载后,放在.keras/model 文件夹下。如果用程序下的化比较慢

1. 正文

1.1 方法一:get_file

参数说明:
filename: 下载后保存的文件名
url: 下载地址
cache_subdir: 模型保存的文件夹
get_file返回下载后的绝对路径

# 下载模型参数
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
filename = 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5' # 下载后保存的文件名
weights_path = get_file(filename,WEIGHTS_PATH_NO_TOP,cache_subdir='models_dir')

下载的参数,会保存在如下位置:
在这里插入图片描述
特别说明: 函数会先检查下面文件是否有.h5文件(filename),如果不存在,才会下载。如果下载失败,也可能会不有完整的文件!!!,导致再次下载失败
**解决办法:**加一个参数,就行。
md5_hash: 验证文件是否有损坏。我试了一下,只要有参数就行,比如:checksum=‘3’。

# 下载模型参数
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
filename = 'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5' # 下载后保存的文件名
checksum = '3e9f4e4f77bbe2c9bec13b53ee1c2319'
weights_path = get_file(filename,WEIGHTS_PATH_NO_TOP,md5_hash=checksum,cache_subdir='models_dir')

1.1.1 有关地址

说明: 链接包括两种,分别为:include_top=Ture/False。
VGG16:

WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5')#500MB
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5')#50MB

VGG19:

WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')

ResNet50:

WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')

1.2 方法二:keras.applications

说明: 此方法获取模型参数

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input
import numpy as np

model = VGG16(weights='imagenet',include_top=False)

img_path = 'elephant.jpg'
img = image.load_img(img_path,target_size=(224,224))
x = image.img_to_array(img)
x = np.expand_dims(x,axis=0)
x = preprocess_input(x)
features = model.predict(x)

1.3 补充说明

1.3.1 notop 模型是什么?

表示是否包含最后3个全连接层,用来做微调(fine-tuning) ,专门开源了这类模型。

1.3.2 include_top

关于参数 include_top,如下:

fc_model = VGG16(include_top=True)
notop_model = VGG16(include_top=False)

上例notop_model 模型,为没有全连接层的模型;fc_model为有全连接层的模型

参考文章

[1] https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/utils/get_file
[2] https://blog.csdn.net/weixin_38145317/article/details/97370541
[3] https://blog.csdn.net/a1920993165/article/details/105060972
[4] https://blog.csdn.net/sinat_26917383/article/details/72859145
[5] https://keras.io/zh/applications/

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

胡侃有料

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值