TF2.0使用预训练网络与模型存取

前言:

Tensorflow2.0的keras中为我们封装了许多优秀的模型结构,并且这些模型结构基于imagenet数据集已经trian好,我们可以直接拿来使用,非常方便;这样做的最大好处可以解决训练集不足的问题。 

 常见预训练网络(基于ImageNet):

  •  使用预训练网络

TF2.0中的预训练网络都在包tf.keras.applications中,以VGG16为例:

"""
include_top----表示是否包含全连接层及其输出层
weights='imagenet'-------表示使用网络的默认参数
pooling=None------表示最后一层是否包含Garoupmaxpool
"""
con_base = tf.keras.applications.VGG16(include_top = False,weights = "imagenet",pooling = "max")

model = tf.keras.Sequential([con_base,
          tf.keras.layers.Dense(512,activation = 'relu'),
          tf.keras.layers.Dense(1,activation = 'sigmoid')
         ])

#设置预训练网络con_base的参数不会发生改变
con_base.trainable = False

#开始训练
model.compile(optimizer = tf.keras.optimizers.Adam(),
              loss = tf.keras.losses.binary_crossentropy,
              metrics=["acc"])
  • 模型检查点保存与恢复

方式一:

"""
保存模型方式五:
自定义训练保存模型
"""

"""
第一步:创建检查点保存路径
"""
cp_dir = "F:\MachineLearnDatas\model_save\custom_train_save"
cp_profix = os.path.join(cp_dir,"ckpt")


"""
第二步:创建模型检查点对象
"""
check_point = tf.train.Checkpoint(optimizer = optimizers,
                                  model = model)

"""
第三步:在自定义训练函数中保存检查点
"""
if step % 2 == 0:
    check_point.save(file_prefix = cp_profix)

若模型在训练过程中,出现中断,使用下列方法提取最新检查点,并从该检查开始继续训练,代码如下:

"""
第一步:提取最新的检查点
"""
latestnew_checkpoint = tf.train.latest_checkpoint(cp_dir)

"""
第二步:创建模型检查点对象
注意:这个optimizers与model属于新创建的,还没有加载参数.
"""
check_point = tf.train.Checkpoint(optimizer = optimizers,
                                  model = model)

"""
第三步:开始恢复到最新检查点处
"""
check_point.restore(latestnew_checkpoint)

方式二(通过checkpointManager):

 保存检查点:

""
第一步:
创建检查点对象,并将模型(参数)、优化器等配置到检查点对象中
""
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
""
第二步:
创建检查点管理器对象,它可以帮我们管理检查点对象
""
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
""
第三步:
在训练函数中,设置多少轮保存一下检查点,返回值为保存路径
""
save_path = manager.save()

 恢复检查点:

""
第一步:
创建优化器、模型对象
""
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
""
第二步:
创建检查点、检查点管理器对象
""
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
""
第三步:
在训练前,恢复检查点
""
ckpt.restore(manager.latest_checkpoint)

最后:

""
另外检查点对象含有一个全局变量step
在训练中动态修改step的值
""
ckpt.step.assign_add(1)

""
获取该值,恢复检查点成功与否,也可以通过查看该值得知
""
int(ckpt.step)

""
打印最新的3个检查点
""
print(manager.checkpoints) 
  • 模型存取

"""
保存模型方式一:
tf.keras提供了使用HDF5标准提供基本的保存格式
这种方法保存了以下内容:
      1)模型权重值
      2)模型结构
      3)模型/优化器配置
"""
model.save("F:/MachineLearnDatas/model_save/traditional_save/my_model.h5")

#将模型加载出来---可以直接进行预测
save_model = tf.keras.models.load_model("F:/MachineLearnDatas/model_save/traditional_save/my_model.h5")

"""
保存模型方式二:
仅仅保存模型结构----这种方式要将模型结构保存成json格式,仅仅保存模型的结构,优化器、损失函数都未指定
"""
model_jsons = model.to_json()

#将该json文件写入到磁盘
with open("F:/MachineLearnDatas/model_save/traditional_save/model_structure.json","w") as my_writer:
    my_writer.write(model_jsons)

#将以json文件保存的结构加载出来
with open("F:/MachineLearnDatas/model_save/traditional_save/model_structure.json","r") as my_reader:
    model_structure = my_reader.read()

new_model_structure = tf.keras.models.model_from_json(model_structure)

"""
保存方式三:
仅保存权重,有时我们只需要保存模型的状态(其权重值),而对模型架构不感兴趣。在这种情况下,
可以通过get_weights()获取权重值,并通过set_weights()设置权重值
"""
model_weights = model.get_weights()

#使用第二种模式只加载出模型的结构
new_model_structure2 = tf.keras.models.model_from_json(model_structure)

new_model_structure2.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

#测试该模型的分数,即训练程度
new_model_structure2.evaluate(test_images,test_labels)

new_model_structure2.set_weights(model_weights)

"""
保存方式四:
在训练过程中,保存检查点,在训练期间或者训练结束的时候自动保存检查点。这样一来,在训练中断了后,
可以从该检查点继续向下训练。
使用的回调函数:tf.keras.callbacks.ModelCheckpoint()
"""
checkpoint_path = "F:/MachineLearnDatas/model_save/traditional_save/cp.ckpt"

check_point_callback = tf.keras.callbacks.ModelCheckpoint(filepath = checkpoint_path,
                                   save_weights_only = True)

#添加fit的callbacks中,这步切记不能忘,这样在训练的时候可以自动帮我们保存参数的检查点
model2.fit(dataset, epochs=5, steps_per_epoch=steps_per_epoch,callbacks=[check_point_callback])

"""
加载检查点,继续开始训练
"""
#加载检查点的权重
new_model_structure3.load_weights(checkpoint_path)

new_model_structure3.evaluate(test_images,test_labels)

2020年9月21日,记住这一天,机会一定只会留给有准备的人,加油!

  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值