持久化原理及数据格式

    当调用saver.save函数时,TensorFlow程序会自动生成4个文件。TensorFlow模型 的持久化就是通过这4个文件完成。本文将介绍这4个文件保存的内容以及数据格式。

    TensorFlow是一个通过图的形式来表述计算的编程系统,TensorFlow程序中的所有计算都会被表达为计算图上的节点。TensorFlow通过元图(MetaGraph)来记录计算图中节点信息以及运行计算图中所需要的元数据。TensorFlow中的元图是由MetaGraphDefault Protocol Buffer定义的。MetaGraphDef 中的内容就构成了TensorFlow持久化时的第一个文件。以下代码给出了MetaGraphDefault类型的定义。

message MetaGraphDef {

    MetaInfoDef meta_info_def = 1;
    
    GraphDef graph_def = 2;
    SaverDef = 3;
    map<string, CollectionDef> collection_def = 4;
    map<string, SignatureDef signature_def = 5;
    repeated AssetFileDef asset_file_def = 6;
 }

    从以上代码可知,元图中主要记录了6类信息。保存MetaGraphDef信息的文件默认以.meta为后缀名。为了方便调试,TensorFlow提供了export_meta_graph函数,这个函数支持以json格式导出MetaGraphDef Protocol Buffer。以下代码展示了如何使用这个函数。

import tensorflow as tf

#定义变量相加的计算
v1 = tf.Variable(tf.constant(1.0, shape=[1], name='v1'))
v2 = tf.Variable(tf.constant(2.0, shape=[1], name='v2'))
result = v1+v2

saver = tf.train.Saver()
#通过export_meta_graph函数导出TensorFlow计算图的元图,并保存为json格式
saver.export_meta_graph("/path/to/model.ckpt.meda.json", as_text=True)

    通过上面给出的代码,将计算图元图以json的格式导出并存储在model.ckpt.meta.json文件中。

    之后将具体介绍TensorFlow元图中存储的信息

    meta_info_def属性

    meta_info_def属性是通过MetaInfoDef定义的,它记录了TensorFlow计算图中的元数据以及TensorFlow程序中所有使用到的运算方法的信息。下面是MetaInfoDef Protocol Buffer的定义:

message MetaInfoDef {
    string meta_graph_version = 1;
    OpList stripped_op_list = 2;
    google.protobuf.Any any_info = 3;
    repeated string tags = 4;
    string tensorflow_version = 5;
    string tensorflow_git_version = 6;
}

    TensorFlow计算图的元数据包括了计算图的版本号(meta_graph_version属性)以及用户指定的一些标签(tags属性)。如果没有在saver中特殊指定,那么这些属性都默认为空。在model.ckpt.meta.json文件中,meta_info_def属性里只有stripped_op_list属性是不为空的。stripped_op_list属性记录了TensorFlow计算图上使用的所有运算方法的信息。注意到stripped_op_list属性保存的是TensorFlow运算方法的信息,所以如果某一个云散在TensoFlow计算图中出现了多次,那么在stripped_op_list中也只会出现一次。stripped_op_list属性的类型是OpList。OpList类型是一个OpDef类型的列表,以下代码给出了OpDef类型的定义。

message OpDef {

    string name = 1;
    
    repeted ArgDef input_arg = 2;
    repeted ArgDef output_arg = 3;
    repeted AttrDef attr = 4;
    
    OpDeprecation deprecation = 8;
    string summary = 5;
    string description = 6;
    bool is_commutative = 18;
    bool is_aggregate = 16;
    bool is_stateful = 17;
    bool allows_uninitialized_input = 19;
}

    OpDef类型中前4个属性定义了一个运算最核心的信息。OpDef中第一个属性name 定义了运算的名称,这也是一个运算为一个的标识符。在TensorFlow计算图元图的其他属性中,比如下面将要介绍的GraphDef属性,将通过运算名称来引用不同的运算。OpDef的第二和第三个属性为inpit_arg和output_arg,他们定义了运算的输入和输出。因为输入和输出可以有多个,所以这两个属性都是列表(repeated)。第四个属性attr给出了其他的运算参数信息。在model.ckpt.meta.json文件中总共定义了8个运算,下面将给出比较有代表性的一个辅助说明OpDef的数据结构。

op {

    name: "Add"
    input_arg {
        name: "x"
        type_attr: "T"
    }
    input_arg {
        name: "y"
        type_attr: "T"
    }
    output_arg {
        name: "a"
        type_attr: "T"
    }
    attr {
        name: "T"
        type: "type"
        allow_values {
            list {
                type: DT_HALF
                type: DT_FLOAT
                ...
            }
        }
    }
}

    上面给出了名称为Add的运算。这个运算有2个输入和1个输出,输入和输出属性都指定了属性type_attr,并且这个属性的值为T。在OpDef的attr属性中,必须要出现名称(name)为T的属性。以上样例中,这个属性指定了运算输入输出允许的参数类型(allowed_values)。

    MetaInfoDef的tensorflow_version和tensorflow_git_version属性记录了生成当前计算图的TensorFlow版本 

graph_def属性

    graph_def属性主要记录了TensorFlow计算图上的节点信息。TensorFlow计算图的每一个节点对应了TensorFlow程序中的一个运算。因为在meta_info_def属性中已经包含了所有运算的具体信息,所以graph_def属性值关注运算的连接结构。graph_def属性是通过GraphDef Protocol Buffer定义的,GraphDef主要包含了一个NodeDef类型的列表,以下代码给出了GraphDef和NodeDef类型中包含的信息:

message GraphDef {
    repeated NodeDef node = 1;
    versionDef versions = 4;
    
    #还有一些已经不同的或者还在试验中的属性,本文不作详细介绍
}

message NodeDef {
    string name = 1;
    string op = 2;
    repeated string input = 3;
    string device = 4;
    map<string, AttrValue> attr = 5;
}

    GraphDef中的version属性比较简单,它可以存储了TensorFlow的版本号。GraphDef的主要信息都存在node属性中,它记录了TensorFlow计算图上所有的节点信息。和其他属性类似,NodeDef类型中有一个属性name,它是节点的唯一标识符。在TensorFlow中可以通过节点的名称来获取相应的节点。NodeDef类型中的op属性给出了该节点使用的TensorFlow运算方法的名称,通过这个名称可以在TensorFlow计算图元图的meta_info_def属性中找到该运算的具体信息。

    NodeDef类型中的input属性是一个字符串列表,它定义了运算的输入。input属性中每个字符串的取值格式为node:src_output,其中node部分给出了一个节点的名称,src_output部分表明了这个输入是指定节点的第几个输出。当src_output为0时,可以省略:src_output这个部分。比如node:0表示名称为node的节点的第一个输出,它也可以被记为node。

    NodeDef类型中的device属性指定了处理这个运算的设备。运行TensorFlow运算的设备可以是本地机器的CPU和GPU,也可以是远程机器的CPU和GPU。当device属性为空时,TensorFlow在运行时会自动选取一个最合适的设备来运行这个运算。最后NodeDef类型中的attr属性指定了和当前运算相关的配置信息。下面举例model.ckpt.meta.json文件中的一些计算节点来更加具体介绍graph_def属性。

graph_def {
    node {
        name: "v1"
        op: "VariableV2"
        attr { 
            key: "_output_shapes"
            value {

                list { shpe { dim { size: 1 } } }
            }
        }
        attr {
            key: "dtype"
            value: {
                type: DT_FLOAT
            }
        }
    ...
    }  
    node {
        name: "add"
        op: "Add"
        input: "v1/read"
        input: "v2/read"
        ...
    }  
    node {
        name:"save/control_dependency"
        op:"Identity"
        ...
    }
    version {
        prodeucer: 24
    }
}

    上面给出了model.ckpt.meta.json文件中graph_def属性里比较有到表型的几个节点,第一个节点给出的是变量定义的运算。在TensorFlow中变量定义也是一个运算,这个运算的名称为v1(name: "v1"),运算方法的名称为Variable(op: "Variable2")。定义变量的运算可以有很多个,于是在NodeDef类型的node属性中可以有多个变量定义的节点。但定义变量的运算方法只用到了一个,于是在MetaInfoDef类型的stripped_op_list属性汇总只有一个名称为VariableV2的运算方法。除了制定计算图中节点的名称和运算方法,NodeDef类型中还定义了运算相关的属性。在节点v1中,attr属性指定了这个变量的维度以及类型。

    给定第二个节点是代表加法运算的节点。它指定了2个输入,一个为v1/read,另一个为v2/read。其中v1/read代表的节点可以读取变量v1的值。因为v1的值是节点v1/read的第一个输出,所以后面的:0就可以省略了。v2/read也类似的代表v2的取值。以上样例文件给出的最后一个名称为save/control_dependency,改节点是系统在完成TensorFlow模型持久化过程中自动生成的一个运算。在样例的最后,属性version给出了生成model.ckpt.meta.json文件时使用的TensorFlow版本号。

saver_def属性

    saver_def属性中记录了持久化模型时需要使用到的一些参数,比如保存到文件的文件名、保存操作和加载操作的名称以及保存频率、清理历史记录等。saver_def属性的类型为SaverDef,其定义如下。

message SaverDef {
    string filename_tensor_name = 1;
    string save_tensor_name = 2;
    string restore_op_name = 3;
    int32 max_to_keep = 4;
    bool sharded = 5;
    float keep_checkpoint_every_n_hours = 6;
    
    enum CheckpointFormatVersion {
        LEGACT = 0;
        v1 = 1;
        v2 = 2;
    }
    CheckpointFormatVersion version = 7;
}

下面给出了model.ckpt.meta.json文件中saver_def属性的内容

saver_def {
    filename_tensor_name: "save/Const:0"
    save_tensor_name: "save/control_dependency:0"
    restore_p[_name: "save/restore_all"
    max_to_keep: 5
    keep_checkpoint_every_n_hours: 1000.0
    version:V2
}

filename_tensor_name属性给出了保存文件名的张量名称,这个张量就是节点save/Const的第一个输出。save_tensor_name属性给出了持久化TensorFlow模型的运算所对应的节点名称,从上面文件来看,这个节点就是在graph_def属性中给出的save/control_dependency节点。和持久化TensorFlow模型运算对应的是加载TensorFlow模型的运算,这个运算的名称由restore_op_name属性指定。max_to_keep属性和keep_checkpoint_every_n_hours属性设定了tf.train.Saver类清理之前保存的模型策略。比如当max_to_keep为5时,在第六次调用saver.save时,第一次保存的模型就会被自动删除。通过设置keep_checkpoint_every_n_hours,每n小时可以在max_to_keep的基础上多保存一个模型。

collection_def属性

    在TensorFlow的计算图(tf.Graph)中可以维护不同集合,而维护这些集合的底层实现就是通过collection_def这个属性。collection_def属性是一个从集合名称到集合内容的映射,其中集合名称为字符串,而集合内容为CollectionDef Protocol Buffer。以下代码给出了CollectionDef类型的定义。

message CollectionDef {
    message NodeList {
        repeated string value = 1;
    }
    message BytesList {
        repeated int64 value = 1 [packed = true];
    }

    message FloatList {
        repeated float value = 1 [packed = true];
    }
    message AnyList {
        repeated goole.protobuf.Any value = 1;
    }

    oneof kind {
        NodeList node_list = 1;
        BytesList bytes_list = 2;
        Int64List int64_list = 3;
        FloatList float_list = 4;
        AnyList any_list = 5;
    }
    
}

通过以上定义可以看出,TensorFlow计算图上的集合主要可以维护4类不同的集合。NodeList用户维护计算图上节点的集合。BytesList可以维护字符串或者系列化之后的Proctor Buffer的集合。比如张量是通过Protocol Buffer表示的,而张量的集合是通过BytesList维护的,我们将model.ckpt.meta.json文件中看到具体样例。Int64List用于维护整数集合,FloatList用户维护实数集合。下面给出了model.ckpt.meta.json文件中collection_def属性的内容。

collection_def {
    key: "trainable_variables"
    value {
        bytes_list {
            value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
            value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
        }   
    }
}
collection_def {
    key: "variables"
    value {
        bytes_list {
            value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
            value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
        }   
    }
}

    以上文件可以看出样例程序中维护了两个集合。一个是所有变量的集合,这个集合的名称为variables。另一个是可训练变量的集合,名为trainable_variables。在样例程序中,这两个集合中的元素是一样的,都是变量v1和v2。他们都是系统自动维护的。

     tf.Saver得到的model.ckpt.index和model.ckpt.data-*****-of-*****文件就保存了所有变量的取值。其中model.ckpt.data文件时通过SSTable格式储存的,可以大致理解为就是一个(key, value)列表。TensorFlow提供了tf.train.NewCheckpointReader类来查看保存的变量信息。以下大妈展示如何使用tf.train.NewCheckpointReader类。

import tensroflow as tf

#tf.train.NewCheckpointReader可以读取checkpoint文件中保存的所有变量
#注意后面的.data和.index可以省去
reader = tf.train.NewCheckpointReader('/path/to/model/model.ckpt')

#获取所有变量列表。这个是一个从变量名到变量维度的字典
global_variables = reader.get_variable_to_shape_map()
for variable_name in global_variable_to_shape_map()
    #variable_name为变量名称,global_variables[variable_name]为变量的维度
    print(variable_name, global_variables[variable_name])

#获取名称为v1变量的取值
print("Value for variable v1 is ", reader.get_tensor("v1"))

    最后一个文件的名字是固定的,叫checkpoint。这个文件时tf.train.Saver类自动生成且自动维护的。在checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型的文件名。当保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也从checkpoint文件删除。checkpoint中内容的格式为CheckpointState Protocol Buffer,下面给出了CheckpointState类型的定义

message CheckpointState {
    string model_checkpoint_path = 1;
    repeated string all_model_checkpoint_paths = 2;
}

    model_checkpoint_path属性保存了最新的TensorFlow模型文件的文件名。all_model_checkpoint_paths属性列出了当前还没有被删除的所有TensorFlow模型文件的文件名。

model_checkpoint_path:"/path/to/model/model.ckpt"
all_model_chekpoint_paths: "/path/to/model/model.ckpt"

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值