Tensorflow持久化原理及数据格式

Tensorflow持久化原理及数据格式


Tensorflow是一个通过图的形式来表述计算的编程系统,Tensorflow中所有的计算都会被表达成计算图上的节点。Tensorflow通过元图(MetaGraph)来记录计算图中的信息,以及运行计算图中节点所需要的元数据。以下代码给出了MetaGraphDef类型的定义
message MetaGraphDef{
    MeatInfoDef meta_info_def = 1;
    GraphDef graph_def = 2;
    SaverDef saver_def = 3;
    map<string,CollectionDef> collection_def = 4;
    map<string,SignatureDef> signature_def = 5;
}

保存MetaGraphDef信息的文件默认以.meta为后缀名,在之前的例子中文件test.ckpt.meta中存储的就是元图的数据。由于得到的是二进制文件不方便查看。为了方便调试,Tensorflow提供了export_meta_graph函数,这个函数支持以json格式导出MetaGraphDef。下面为实现的代码
import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name ="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name ="v2")
result1 = v1 + v2

saver = tf.train.Saver()
#通过export_meta_graph函数导出Tensorflow的计算图的元图,并保存为json格式
saver.export_meta_graph("test/test.ckpt.json", as_text=True)

meta_info_def 属性

meta_info_def属性通过MetaInfoDef定义,它记录了Tensorflow计算图中的元数据以及Tensorflow程序中所有使用到的运算方法的信息。下面的MetaInfoDef的定义
message MetaInfoDef{
    #saver没有特殊指定,默认属性都为空。meta_info_def属性里只有stripped_op_list属性不能为空。
    string meta_graph_version = 1;#该属性不能为空
    OpList stripped_op_list = 2;#该属性记录了计算图中使用到的所有运算方法的信息,该函数只记录运算信息,不记录计算的次数
    google.protobuf.Any any_info = 3;
    repeated string tags = 4;
}
元数据包括计算图的版本号(meta_graph_version属性)以及用户指定的一些标签(tags属性)。

OpList类型是一个OpDef类型的列表,以下代码给出OpDef类型的定义:
message opDef{
    string name = 1;#定义了运算的名称
    repeated ArgDef input_arg = 2; #定义了输入,属性是列表
    repeated ArgDef output_arg =3; #定义了输出,属性是列表
    repeated AttrDef attr = 4;#给出了其他运算的参数信息
    string summary = 5;
    string description = 6;
    OpDeprecation deprecation = 8;
    bool is_commutative = 18;
    bool is_aggregate = 16
    bool is_stateful = 17;
    bool allows_uninitialized_input = 19;
};
下面给出一个比较有代表性的运算来辅助说明OpDef的数据结构。
op {
    name: "Add"
    input_arg{
        name: "x"
        type_attr:"T"
    }
    input_arg{
        name: "y"
        type_attr:"T"
    }
    output_arg{
        name: "z"
        type_attr:"T"
    }
    attr{
        name:"T"
        type:"type"
        allow_values{
            list{
                type:DT_HALF
                type:DT_FLOAT
                ...
            }
        }
    }
}

上面给出的是名称为Add的运算。这个运算的输入有两个,输出有一个,输入输出属性均指定了属性typr_attr,并且这个属性的值为T。在OpDef的attr的属性中。必须要出现名称(name)为T的属性。以上样例中,这个属性指定了运算输入输出允许的参数类型 (allowed_values)。


graph_def属性

graph_def属性主要记录了计算图中的节点信息。Tensorflow计算图中的一个节点对应Tensorflow中的一个运算。因为meta_info_def中已经包含所有运算的具体信息,所以graph_def属性指关注运算的连接结构。GraphDef主要包含了一个NodeDef类型的列表。以下代码给出GraphDef和NodeDef类型中包含的信息:
message GraphDef{
    #GraphDef的主要信息存储在node属性中,他记录了Tensorflow计算图上所有的节点信息。
    repeated NodeDef node = 1;
    VersionDef versions = 4; #主要储存了Tensorflow的版本号
};

message NodeDef{
    #NodeDef类型中有一个名称属性name,他是一个节点的唯一标识符,在程序中,通过节点的名称来获得相应的节点。
    string name = 1;

    '''
    op属性给出了该节点使用的Tensorflow运算方法的名称。
    通过这个名称可以在TensorFlow计算图元图的meta_info_def属性中找到该运算的具体信息。
    '''
    string op = 2;

    '''
    input属性是一个字符串列表,他定义了运算的输入。每个字符串饿的取值格式为弄的:src_output
    node部分给出节点名称,src_output表明了这个输入是指定节点的第几个输出。
    src_output=0时可以省略src_output部分
    '''
    repeated string input = 3;

    #制定了处理这个运算的设备,可以是本地或者远程的CPU or GPU。属性为空时自动选择
    string device = 4;

    #制定了和当前运算有关的配置信息
    map<string, AttrValue> attr = 5;
};
下面列举test.ckpt.meta.json具体介绍graph_def属性
graph def {
    node {
        name: "v1"
        op: "Variable"
        attr {
            key:"_output_shapes"
            value {
                list{ shape { dim { size: 1 } } }
            }
        }
    }
    attr { 
        key :"dtype"
        value {
            type: DT_FLOAT
            }
        }           
        ...
    }
    node {
        name :"add"
        op :"Add"
        input :"v1/read" #read指读取变量v1的值
        input: "v2/read"
        ...
    }
    node {
        name: "save/control_dependency" #指系统在完成tensorflow模型持久化过程中自动生成一个运算。
        op:"Identity"
        ...
    }
    versions {
        producer :9 #给出了文件使用时的Tensorflow版本号。
    }
}

saver_def属性

message SaverDef {
    string filename_tensor_name = 1;
    string save_tensor_name = 2;
    string restore_op_name = 3;
    int32 max_to_keep = 4;
    bool sharded = 5;
    float keep_checkpoint_every_n_hours = 6;
    enum CheckpointFormatVersion {
        LEGACY = 0;
        V1 = 1;
        V2 = 2;
    }
    CheckpointFormatVersion version = 7;
}
下面给出test.ckpt.meta.json文件中saver_def属性的内容。
saver_def {
    filename_tensor_name :"save/Const:0”
    #给出了保存文件的张量名,这个张量就是节点save/Const的第一个输出。

    save_tensor_name :"save/control_dependency: 0#给出了持久化模型运算所对应的节点名称

    restore_op_name: "save/restore_all"
    #和持久性模型运算对应的是加载模型的运算的名称

    max_to_keep:5
    keep_checkpoint_every_n_hours :10000.0
    '''
    上面两个属性设定了tf.train.Saver类清理之前保存的模型的策略。比如当max_to_keep为5时,第六次调用
    saver.save时,第一次保存的模型就会被自动删除,通过设置keep_checkpoint_every_n_hours,每n小
    时可以在max_to_keep的基础上保存一个模型
    '''

collection_def属性

collection_def属性是一个集合名称到集合内容的映射,其中集合的名称为字符串,而集合内容为CollectionDef Protocol Buffer。以下代码给出CollectionDef类型的定义
message CollectionDef {
    message Nodelist {
    #用于维护计算图上的节点集合
        repeated string value = 1;
    }

    message BytesList {
    #维护字符串或者系列化之后的Procotol Buffer的集合。例如张量是通过Protocol Buffer表示的,而张量的集合是通过BytesList维护的。
        repeated bytes value = 1 ;
    }

    message Int64List {
        repeated int64 value = 1[packed = true];
    }
    message FloatList {
        repeated float value = 1[packed = true] ;
    }
    message AnyList {
        repeated google.protobuf.Any value= 1;
    }
    oneof kind {
        NodeList node_list = 1;
        BytesList bytes_lista = 2;
        Int64List int64_list = 3;
        Floatlist float_list = 4;
        AnyList any_list = 5;
    }
}
下面给出了test.ckpt.meta.json文件中的collection_def属性的内容
collection_def {
    #可训练变量的集合
    key: "trainable_variables"
    value {
        bytes_list {
            value; "\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
            value: "\n\004v2:0\022\tv2/Assign\032\cv2/read:0"
        }
    }
}
collection_def {
    #所有变量的集合
    key: "variables"
    value {
        bytes_list {
            value:"\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
            value:"\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
        }
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值