1. 在 OpenNMT 框架代码中,ensemble(checkpoint 平均)的做法在 opennmt/utils/checkpoint.py 中:
def average_checkpoints(model_dir, output_dir, max_count=8, session_config=None):
    """Averages checkpoints.

    Trainable variables are averaged over the most recent checkpoints; any
    non-trainable variable is copied verbatim from the newest checkpoint.

    Args:
      model_dir: The directory containing checkpoints.
      output_dir: The directory that will contain the averaged checkpoint.
      max_count: The maximum number of checkpoints to average.
      session_config: Configuration to use when creating the session.

    Returns:
      The path to the directory containing the averaged checkpoint.

    Raises:
      ValueError: if :obj:`output_dir` is the same as :obj:`model_dir`.
    """
    if model_dir == output_dir:
        raise ValueError("Model and output directory must be different")

    state = tf.train.get_checkpoint_state(model_dir)
    # Paths are ordered oldest-to-newest; keep at most the last max_count.
    paths = list(state.all_model_checkpoint_paths)[-max_count:]
    total = len(paths)

    tf.logging.info("Averaging %d checkpoints..." % total)
    tf.logging.info("Listing variables...")

    averaged = {}
    last_index = total - 1
    for index, path in enumerate(paths):
        tf.logging.info("Loading checkpoint %s" % path)
        for name, value in six.iteritems(get_checkpoint_variables(path)):
            if _variable_is_trainable(name, value):
                # Accumulate the running mean: each checkpoint contributes 1/total.
                averaged[name] = averaged.get(name, 0) + value / total
            elif index == last_index:
                # Take non trainable variables from the last checkpoint.
                averaged[name] = value

    return _create_checkpoint_from_variables(
        averaged,
        output_dir,
        session_config=session_config)
对模型中全部参数进行平均,平均后创建新的checkpoint
2. 在 nmt(https://github.com/tensorflow/nmt,TensorFlow 的一个框架,工程师个人维护)中,
ensemble 的做法在 nmt/model_helper.py 中,函数为 avg_checkpoints。
做法和1基本一致,对于不训练的变量等问题处理细节不同,思路都是将所有的权重取平均。
def avg_checkpoints(model_dir, num_last_checkpoints, global_step, global_step_name):
  """Average the last N checkpoints in the model_dir.

  Args:
    model_dir: Directory containing the checkpoints to average.
    num_last_checkpoints: Number of most recent checkpoints to average.
    global_step: Value to store in the global step variable of the new
      checkpoint.
    global_step_name: Name of the global step variable; it is excluded from
      the averaging.

  Returns:
    Path of the directory holding the averaged checkpoint, or None when
    averaging was skipped (no checkpoint state, or too few checkpoints).
  """
  checkpoint_state = tf.train.get_checkpoint_state(model_dir)
  if not checkpoint_state:
    utils.print_out("# No checkpoint file found in directory: %s" % model_dir)
    return None

  # Checkpoints are ordered from oldest to newest.
  checkpoints = (
      checkpoint_state.all_model_checkpoint_paths[-num_last_checkpoints:])

  if len(checkpoints) < num_last_checkpoints:
    utils.print_out(
        "# Skipping averaging checkpoints because not enough checkpoints is "
        "available."
    )
    return None

  avg_model_dir = os.path.join(model_dir, "avg_checkpoints")
  if not tf.gfile.Exists(avg_model_dir):
    utils.print_out(
        "# Creating new directory %s for saving averaged checkpoints." %
        avg_model_dir)
    tf.gfile.MakeDirs(avg_model_dir)

  utils.print_out("# Reading and averaging variables in checkpoints:")
  var_list = tf.contrib.framework.list_variables(checkpoints[0])
  var_values, var_dtypes = {}, {}
  # Zero-initialize an accumulator per variable (global step excluded).
  for (name, shape) in var_list:
    if name != global_step_name:
      var_values[name] = np.zeros(shape)

  # Sum every checkpoint's tensors into the accumulators, recording each
  # variable's dtype as read from the checkpoint.
  for checkpoint in checkpoints:
    utils.print_out(" %s" % checkpoint)
    reader = tf.contrib.framework.load_checkpoint(checkpoint)
    for name in var_values:
      tensor = reader.get_tensor(name)
      var_dtypes[name] = tensor.dtype
      var_values[name] += tensor

  # Turn the sums into means.
  for name in var_values:
    var_values[name] /= len(checkpoints)

  # Build a graph with same variables in the checkpoints, and save the averaged
  # variables into the avg_model_dir.
  with tf.Graph().as_default():
    tf_vars = [
        # BUG FIX: the dtype lookup previously read `var_dtypes[name]`, where
        # `name` is the stale loop variable left over from the reading loop
        # above -- every variable got the dtype of the last tensor read.
        # Each variable must be created with its own recorded dtype.
        tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
        for v in var_values
    ]
    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
    # The global step variable is re-created (not averaged) so it is part of
    # the saved checkpoint.
    global_step_var = tf.Variable(
        global_step, name=global_step_name, trainable=False)
    saver = tf.train.Saver(tf.all_variables())

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      # Relies on dict iteration order being stable: tf_vars, placeholders and
      # assign_ops were all built by iterating var_values in the same order.
      for p, assign_op, (name, value) in zip(
          placeholders, assign_ops, six.iteritems(var_values)):
        sess.run(assign_op, {p: value})

      # Use the built saver to save the averaged checkpoint. Only keep 1
      # checkpoint and the best checkpoint will be moved to avg_best_metric_dir.
      saver.save(
          sess,
          os.path.join(avg_model_dir, "translate.ckpt"))

  return avg_model_dir
都是平均后重新构建一个新的checkpoint
3. OpenNMT 的 PyTorch 版本是包含 ensemble 的。
一个使用 TensorFlow 构建 ensemble 的例子:https://github.com/zhaocq-nlp/NJUNMT-tf,构建在 njunmt/models/ensemble_model.py 和