TensorFlow Model Parameter Saving Methods

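Saving TensorFlow model parameters falls into two cases: unpruned parameters and pruned parameters. Unpruned parameters can be written directly to a TXT file, while pruned parameters can be saved in a sparse-matrix format.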

 

I. Saving Unpruned Parameters

Before the parameters can be read, the model and its parameters must be saved. Save them as follows:

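def save_model(self):
    # Save every variable in self.sess to save/model.ckpt (TensorFlow 1.x tf.train.Saver API)
    saver = tf.train.Saver()
    save_path = saver.save(self.sess, "save/model.ckpt")
    return save_path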

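Calling this writes the checkpoint into the save/ directory (a checkpoint state file plus model.ckpt.index, model.ckpt.data-* and model.ckpt.meta), and the parameters can then be read back from the ckpt files.

To save the parameters, you only need to read each parameter's name and data from the checkpoint. Example code: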

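import tensorflow as tf

model_dir = "save/"
ckpt = tf.train.get_checkpoint_state(model_dir)
ckpt_path = ckpt.model_checkpoint_path
# open a reader on the checkpoint file (this does not import the graph;
# in TensorFlow 2.x use tf.compat.v1.train.NewCheckpointReader or tf.train.load_checkpoint)
reader = tf.train.NewCheckpointReader(ckpt_path)
all_variables = reader.get_variable_to_shape_map()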

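The statements above do the following, in order:

1. Set the checkpoint directory.

2. Read the checkpoint state from that directory.

3. Get the path of the latest ckpt file.

4. Create a reader that gives access to the parameter names and data.

5. Get a dict mapping every parameter name to its shape.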
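Steps 2 and 3 only read a small text file. For a directory written by tf.train.Saver, the file named checkpoint is a text-format protobuf whose model_checkpoint_path field names the latest checkpoint prefix. The stdlib sketch below reproduces that lookup, assuming that text layout; latest_checkpoint_path is a hypothetical helper, not a TensorFlow API (in practice tf.train.get_checkpoint_state or tf.train.latest_checkpoint does this for you).

```python
import os
import re


def latest_checkpoint_path(model_dir):
    """Hypothetical helper: return the latest checkpoint prefix recorded in
    model_dir/checkpoint, or None if that file is missing.

    Assumes the text-format layout written by tf.train.Saver, e.g.
        model_checkpoint_path: "model.ckpt"
    """
    state_file = os.path.join(model_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file, encoding="utf-8") as f:
        for line in f:
            m = re.match(r'\s*model_checkpoint_path:\s*"(.*)"\s*$', line)
            if m:
                path = m.group(1)
                # relative prefixes are relative to the checkpoint directory
                return path if os.path.isabs(path) else os.path.join(model_dir, path)
    return None
```

With the save_model call above, latest_checkpoint_path("save/") would be expected to return "save/model.ckpt", the same value as ckpt.model_checkpoint_path.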

 

Once you have the parameter names, use get_tensor to fetch each parameter's data as a NumPy array:

parameter_data = reader.get_tensor(key)

 

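Writing to TXT has to be done value by value: the arrays are large, so writing str(parameter_data) directly would fill the file with ellipses (NumPy abbreviates large arrays when converting them to strings). The procedure is as follows: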

# pf is the output file, opened with open('parameter.txt', 'w') in the complete code below
for key in all_variables.keys():
    parameter_data = reader.get_tensor(key)
    print('**************** save', key, 'succeed ******************* shape:', parameter_data.shape)
    data_shape = parameter_data.shape
    # header line: parameter name and shape
    pf.write(str(key))
    pf.write(',data shape:')
    pf.write(str(all_variables[key]))
    pf.write('\n')
    # save scalar (0-D) data
    if len(data_shape) == 0:
        pf.write(str(parameter_data))
        pf.write('\n')
    # save 1-D data format
    if len(data_shape) == 1:
        pf.write('{')
        for i in range(data_shape[0]):
            pf.write(str(parameter_data[i]))
            if i < data_shape[0] - 1:
                pf.write(',')
            else:
                pf.write('}')
        pf.write('\n')
    # save 2-D data format (row-major order)
    if len(data_shape) == 2:
        pf.write('{')
        for i in range(data_shape[0]):
            for j in range(data_shape[1]):
                pf.write(str(parameter_data[i][j]))
                # comma after every element except the very last one
                if i < data_shape[0] - 1 or j < data_shape[1] - 1:
                    pf.write(',')
                else:
                    pf.write('}')
        pf.write('\n')
    # save 4-D data format (e.g. convolution kernels), row-major order
    # note: 3-D tensors are not handled here
    if len(data_shape) == 4:
        pf.write('{')
        for i in range(data_shape[0]):
            for j in range(data_shape[1]):
                for k in range(data_shape[2]):
                    for l in range(data_shape[3]):
                        pf.write(str(parameter_data[i][j][k][l]))
                        if (i < data_shape[0] - 1 or j < data_shape[1] - 1
                                or k < data_shape[2] - 1 or l < data_shape[3] - 1):
                            pf.write(',')
                        else:
                            pf.write('}')
        pf.write('\n')
pf.close()

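The idea is to make the data easy to port to C as a #define: following C initializer syntax, each array starts with { and ends with }, and the values in between are written one at a time, separated by commas. The header written for each parameter can of course be changed too: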

pf.write(str(key))
pf.write(',data shape:')
pf.write(str(all_variables[key]))
pf.write('\n')

These lines can be changed to produce a directly usable form such as #define PARAMETER_STR DATA.
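As a sketch of that change, the hypothetical helper define_line below turns a checkpoint variable name into a C macro name (characters that are not valid in C identifiers become _, and the result is upper-cased) and builds one #define line per parameter. The naming rule is an assumption; adjust it to your project's conventions. The values are taken as a flat sequence, for example the row-major output of parameter_data.ravel().

```python
import re


def define_line(var_name, flat_values):
    """Hypothetical helper: build '#define NAME {v0,v1,...}' for one parameter.

    var_name is a checkpoint key such as 'conv1/weights'; every character
    that is not a letter, digit or underscore becomes '_', and the result
    is upper-cased so it can serve as a C macro name.
    """
    macro = re.sub(r"[^0-9A-Za-z_]", "_", var_name).upper()
    if not macro or macro[0].isdigit():
        macro = "P_" + macro  # C identifiers cannot start with a digit
    values = ",".join(str(v) for v in flat_values)
    return "#define {} {{{}}}\n".format(macro, values)


print(define_line("conv1/weights", [0.5, -1.25, 3.0]), end="")
# #define CONV1_WEIGHTS {0.5,-1.25,3.0}
```

On the C side such a macro can then be used as an array initializer, e.g. static const float conv1_weights[] = CONV1_WEIGHTS;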

 

The complete code is as follows:

import tensorflow as tf
import numpy as np  # used by the sparse-matrix code in section II
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


def view_parameter(model_dir):
    ckpt = tf.train.get_checkpoint_state(model_dir)
    # print("ckpt :", ckpt)
    ckpt_path = ckpt.model_checkpoint_path
    # open a reader on the checkpoint (this does not import the graph)
    reader = tf.train.NewCheckpointReader(ckpt_path)
    all_variables = reader.get_variable_to_shape_map()
    print(all_variables)
    # view each parameter's name, shape and data
    for key, val in all_variables.items():
        try:
            print(key, val)
            data = reader.get_tensor(key)
            print(data)
            # print_tensors_in_checkpoint_file(ckpt_path, tensor_name=key, all_tensors=False, all_tensor_names=False)
        except Exception as e:
            print('failed to read', key, ':', e)
    return all_variables, ckpt_path


def save_parameter_txt(model_dir="save/"):
    ckpt = tf.train.get_checkpoint_state(model_dir)
    ckpt_path = ckpt.model_checkpoint_path
    # open a reader on the checkpoint (this does not import the graph)
    reader = tf.train.NewCheckpointReader(ckpt_path)
    all_variables = reader.get_variable_to_shape_map()
    print(all_variables)

    pf = open('parameter.txt', 'w')
    # loop over all parameters and write them to the txt file
    for key in all_variables.keys():
        parameter_data = reader.get_tensor(key)
        print('**************** save', key, 'succeed ******************* shape:', parameter_data.shape)
        data_shape = parameter_data.shape
        # header line: parameter name and shape
        pf.write(str(key))
        pf.write(',data shape:')
        pf.write(str(all_variables[key]))
        pf.write('\n')
        # save scalar (0-D) data
        if len(data_shape) == 0:
            pf.write(str(parameter_data))
            pf.write('\n')
        # save 1-D data format
        if len(data_shape) == 1:
            pf.write('{')
            for i in range(data_shape[0]):
                pf.write(str(parameter_data[i]))
                if i < data_shape[0] - 1:
                    pf.write(',')
                else:
                    pf.write('}')
            pf.write('\n')
        # save 2-D data format (row-major order)
        if len(data_shape) == 2:
            pf.write('{')
            for i in range(data_shape[0]):
                for j in range(data_shape[1]):
                    pf.write(str(parameter_data[i][j]))
                    # comma after every element except the very last one
                    if i < data_shape[0] - 1 or j < data_shape[1] - 1:
                        pf.write(',')
                    else:
                        pf.write('}')
            pf.write('\n')
        # save 4-D data format (e.g. convolution kernels), row-major order
        # note: 3-D tensors are not handled here
        if len(data_shape) == 4:
            pf.write('{')
            for i in range(data_shape[0]):
                for j in range(data_shape[1]):
                    for k in range(data_shape[2]):
                        for l in range(data_shape[3]):
                            pf.write(str(parameter_data[i][j][k][l]))
                            if (i < data_shape[0] - 1 or j < data_shape[1] - 1
                                    or k < data_shape[2] - 1 or l < data_shape[3] - 1):
                                pf.write(',')
                            else:
                                pf.write('}')
            pf.write('\n')
    pf.close()


if __name__ == '__main__':
    save_parameter_txt()
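The nested loops in save_parameter_txt cover only 0-D, 1-D, 2-D and 4-D tensors. A shorter alternative, sketched below and not the author's code, flattens the array in row-major (C) order, which is the same order as the nested i/j/k/l loops and as a C array initializer, so one function handles every rank. The hypothetical helper to_c_initializer also shows why str() alone is not enough: by default NumPy abbreviates arrays with more than 1000 elements.

```python
import numpy as np


def to_c_initializer(parameter_data):
    """Hypothetical helper: format an array of any rank as '{v0,v1,...}'.

    ravel() returns the values in row-major (C) order, the same order as
    the nested loops, so they line up with a C array initializer.
    """
    flat = np.asarray(parameter_data).ravel()
    return "{" + ",".join(str(v) for v in flat) + "}"


big = np.zeros((40, 30), dtype=np.float32)  # 1200 elements
print("..." in str(big))               # True: NumPy summarizes large arrays
print("..." in to_c_initializer(big))  # False: every value is written
print(to_c_initializer(np.array([[1, 2], [3, 4]])))  # {1,2,3,4}
```

Inside the loop, the four shape-specific branches could then be replaced by a single pf.write(to_c_initializer(parameter_data) + '\n'); note that a scalar becomes a one-element initializer such as {5.0}.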

 

II. Saving Pruned Parameters as a Sparse Matrix

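To store a pruned sparse matrix, take out the weights and multiply them element-wise by the pruning mask, then store only the nonzero entries as three arrays: {row}{col}{data}.

Taking a fully connected layer as an example, everything is two-dimensional.

Example: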

import numpy as np


def weights_csc_matrix(weights_matrix, mask_matrix):
    # apply the pruning mask: pruned positions become 0
    weights_matrix = np.multiply(weights_matrix, mask_matrix)
    row_data = []
    col_data = []
    weights_data = []
    print(weights_matrix)
    # collect every nonzero entry as (row, col, value), scanning row by row;
    # use != 0 rather than > 0 so negative weights are kept
    for i in range(weights_matrix.shape[0]):
        for j in range(weights_matrix.shape[1]):
            if weights_matrix[i][j] != 0:
                row_data.append(i)
                col_data.append(j)
                weights_data.append(weights_matrix[i][j])
    print(row_data, col_data, weights_data)
    return row_data, col_data, weights_data

The resulting lists are then written to a TXT file in the same way as above.
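The triples returned by weights_csc_matrix are in coordinate (COO) layout, ordered by row. If the C side expects true compressed sparse column (CSC) storage, as the function name suggests, the triples can be regrouped into a column-pointer array. The sketch below assumes plain Python lists as produced by weights_csc_matrix; coo_to_csc and write_sparse_txt are hypothetical helper names, and the TXT layout simply reuses the {...} style from section I.

```python
def coo_to_csc(row_data, col_data, weights_data, n_cols):
    """Hypothetical helper: regroup (row, col, value) triples into CSC arrays.

    Returns (col_ptr, row_idx, values): the nonzeros of column c are
    values[col_ptr[c]:col_ptr[c + 1]], with their row indices in row_idx.
    """
    # count nonzeros per column, then turn the counts into start offsets
    col_ptr = [0] * (n_cols + 1)
    for c in col_data:
        col_ptr[c + 1] += 1
    for c in range(n_cols):
        col_ptr[c + 1] += col_ptr[c]
    # place each entry in its column's next free slot; rows stay ascending
    # within a column because the input triples are ordered by row
    next_slot = col_ptr[:n_cols]
    row_idx = [0] * len(weights_data)
    values = [0.0] * len(weights_data)
    for r, c, v in zip(row_data, col_data, weights_data):
        k = next_slot[c]
        row_idx[k] = r
        values[k] = v
        next_slot[c] += 1
    return col_ptr, row_idx, values


def write_sparse_txt(path, name, arrays):
    """Hypothetical helper: write a name header, then each array as '{a,b,c}'."""
    with open(path, "w") as pf:
        pf.write(name + "\n")
        for arr in arrays:
            pf.write("{" + ",".join(str(x) for x in arr) + "}\n")


# triples for [[0, 2, 0], [-1, 0, 3]]
col_ptr, row_idx, values = coo_to_csc([0, 1, 1], [1, 0, 2], [2, -1, 3], n_cols=3)
# col_ptr == [0, 1, 2, 3], row_idx == [1, 0, 1], values == [-1, 2, 3]
```

Compared with the raw triples, CSC stores one offset per column instead of one column index per nonzero, which is the usual layout for a column-oriented sparse matrix-vector product on the C side.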
