计算图结构
- MetaGraphDef(计算图)
- MetaInfoDef(运算方法)
- OpList(运算方法集合)
- OpDef(运算方法)
- ArgDef(输入,输出)
- AttrDef(属性)
- OpDef(运算方法)
- OpList(运算方法集合)
- GraphDef (连接结构)
- NodeDef(节点)
- SaverDef (模型持久化)
- CheckpointFormatVersion(模型定义使用的版本)
- map<string, CollectionDef> (集合)
- NodeList(节点value)
- BytesList(序列化value)
- map<string, SignatureDef>(签名)
- AssetFileDef (权重值)
- MetaInfoDef(运算方法)
collection_def
message CollectionDef {
// NodeList用于收集图中的节点。
message NodeList {
repeated string value = 1;
}
// BytesList用于收集字符串和序列化的protobufs。
message BytesList {
repeated bytes value = 1;
}
// Int64List用于收集int,int64和long值。
message Int64List {
repeated int64 value = 1 [packed = true];
}
// FloatList用于收集浮点值。
message FloatList {
repeated float value = 1 [packed = true];
}
// AnyList用于收集Any protos。
message AnyList {
repeated google.protobuf.Any value = 1;
}
// 以上定义必须属于oneos中
oneof kind {
NodeList node_list = 1;
BytesList bytes_list = 2;
Int64List int64_list = 3;
FloatList float_list = 4;
AnyList any_list = 5;
}
}
案列
// 1. 对于单一的数据类型, 列如 string, int, float:
tf.add_to_collection("your_collection_name", your_simple_value)
strings 将会被保存为 bytes_list.
// 2. 对于序列化数据, 有3种方法添加:
//1)
tf.add_to_collection("your_collection_name",your_proto.SerializeToString())
collection_def {
key: "user_defined_bytes_collection"
value {
bytes_list {
value: "queue_name: \"test_queue\"\n"
}
}
}
//2)
tf.add_to_collection("your_collection_name", str(your_proto))
collection_def {
key: "user_defined_string_collection"
value {
bytes_list {
value: "\n\ntest_queue"
}
}
}
//3) any_buf = any_pb2.Any()
tf.add_to_collection("your_collection_name",any_buf.Pack(your_proto))
collection_def {
key: "user_defined_any_collection"
value {
any_list {
value {
type_url: "type.googleapis.com/tensorflow.QueueRunnerDef"
value: "\n\ntest_queue"
}
}
}
}
//对于Pyhon类型的对象, implement to_proto() 和 from_proto(), 并以下列方式在tensorflow中进行注册:
ops.register_proto_function("your_collection_name",
proto_type,
to_proto=YourPythonObject.to_proto,
from_proto=YourPythonObject.from_proto)
//并且使用这些函数来序列化和反序列化集合。例如,
ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES,
proto_type=variable_pb2.VariableDef,
to_proto=Variable.to_proto,
from_proto=Variable.from_proto)
NodeList
维护节点集合
// summaries 集合中,收集要保存的节点
collection_def {
key: "summaries"
value {
node_list {
value: "input_producer/ScalarSummary:0"
value: "shuffle_batch/ScalarSummary:0"
value: "ImageSummary:0"
}
}
}
BytesList
维护字符串或者序列化之后的集合
// 所有可以训练变量的集合,以bytes_list二级只能
collection_def {
key: "trainable_variables"
value {
bytes_list {
value: "\n\017conv1/weights:0\022\024conv1/weights/Assign\032\024conv1/weights/read:0"
value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032\023conv1/biases/read:0"
}
}
}