多个checkpoint 的参数进行平均

source_model 路径下 存在 以下几个checkpoint
model_checkpoint_path: "model.ckpt-457157707"
all_model_checkpoint_paths: "model.ckpt-456023526" ,all_model_checkpoint_paths: "model.ckpt-456332667" ,all_model_checkpoint_paths: "model.ckpt-456332668",all_model_checkpoint_paths: "model.ckpt-456832684" ,all_model_checkpoint_paths: "model.ckpt-457157707"

现在将这些ckpt的参数进行平均 合并成一个model.ckpt-457157708 

import tensorflow as tf
import numpy as np

# 获取所有的checkpoint文件
ckpt_files = ["model.ckpt-456023526", "model.ckpt-456332667", "model.ckpt-456332668", "model.ckpt-456832684", "model.ckpt-457157707"]
ckpt_files = [os.path.join("source_model", ckpt_file) for ckpt_file in ckpt_files]

# 用于存储所有模型的参数
all_model_vars = {}

for ckpt_file in ckpt_files:
    reader = tf.train.NewCheckpointReader(ckpt_file)
    model_vars = reader.get_variable_to_shape_map()
    for var in model_vars:
        if var not in all_model_vars:
            all_model_vars[var] = []
        all_model_vars[var].append(reader.get_tensor(var))

# 计算每个参数的平均值
average_vars = {var: np.mean(values, axis=0) for var, values in all_model_vars.items()}

# 创建一个新的checkpoint文件,并将平均后的参数保存到新的.data文件中
with tf.Session() as sess:
    for var_name, var_value in average_vars.items():
        var = tf.get_variable(var_name, initializer=var_value)
        sess.run(var.initializer)

    saver = tf.train.Saver()
    saver.save(sess, "source_model/model.ckpt-457157708")
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值