首先需要了解：protobuf 的 map 字段只是一种语法糖。例如 node_def.proto 中的如下部分：
...
import "tensorflow/core/framework/attr_value.proto";
...
// Excerpt of the NodeDef message from node_def.proto.
message NodeDef {
  string name = 1;
  string op = 2;
  repeated string input = 3;
  string device = 4;

  // Map field: syntactic sugar over a repeated key/value entry message.
  map<string, AttrValue> attr = 5;
}
最后一个 map 字段实际上等价于一个由嵌套条目消息组成的 repeated 字段：每个键值对都被编码成一条条目消息，key 为 1 号字段、value 为 2 号字段，线路格式完全相同。等价写法如下：
// Hand-written equivalent of `map<string, AttrValue> attr = 5;`: a repeated
// field of a nested key/value entry message with the same wire format.
// protoc names the implicit entry message after the field in CamelCase plus
// "Entry", so for `attr` it is AttrEntry (hence the C++ class name
// NodeDef_AttrEntry used by the reflection code).
// The entry that protoc generates also carries `option map_entry = true;`.
// That option may not be set by hand, so this form is wire-compatible with the
// map but not identical to it at the descriptor level.
message NodeDef {
  string name = 1;
  string op = 2;
  repeated string input = 3;
  string device = 4;

  message AttrEntry {
    string key = 1;
    AttrValue value = 2;
  }
  repeated AttrEntry attr = 5;
}
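因此可以借助 protobuf 反射，把 attr 当作 repeated message 逐条读取（若直接使用生成代码，也可以通过 node.attr() 返回的 google::protobuf::Map 调用 find() 按键查找，更为简单）。部分代码如下：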
// Looks up the attribute `needed_attrkey` in `node.attr` through protobuf
// reflection and stores the extracted value in `returnstruct`.
//
// Only the "strides" key is currently supported. Any other key, including an
// empty one, leaves `returnstruct` untouched.
//
// For "strides", `returnstruct.intvalue` receives element 1 of the attr's
// `list.i` values. A warning is printed when element 2 differs from element 1.
// NOTE(review): elements 1 and 2 are the spatial (H, W) strides only for an
// NHWC layout. For NCHW they would sit at indices 2 and 3. Confirm against the
// node's "data_format" attr before relying on this.
//
// @param node            NodeDef to read from.
// @param needed_attrkey  Attribute key to look up.
// @param returnstruct    Output. Only `intvalue` is written, and only when a
//                        "strides" attr with at least 3 integers is found.
void getattrvalue_bykey( const tensorflow::NodeDef &node, string needed_attrkey ,ReturnStruct &returnstruct ) {
    // Loop-invariant check, hoisted out of the per-entry loop.
    if (needed_attrkey != "strides") {
        return;
    }

    // **1**: walk node.attr. Reflection exposes a map field as a repeated
    // message of {key = 1, value = 2} entries.
    const google::protobuf::Reflection* node_reflection = node.GetReflection();
    const google::protobuf::FieldDescriptor* attr_field =
        node.GetDescriptor()->FindFieldByName("attr");
    if (attr_field == nullptr) {
        return;
    }

    for (int j = 0; j < node.attr_size(); ++j) {
        // Bind the entry by const reference rather than MergeFrom-copying it.
        // A copy duplicates the entire AttrValue payload on every iteration.
        // Using the generic Message type also avoids naming the generated
        // entry class, whose name differs across protobuf versions.
        const google::protobuf::Message& entry =
            node_reflection->GetRepeatedMessage(node, attr_field, j);
        const google::protobuf::Reflection* entry_reflection = entry.GetReflection();
        const google::protobuf::Descriptor* entry_descriptor = entry.GetDescriptor();
        const google::protobuf::FieldDescriptor* key_field =
            entry_descriptor->FindFieldByName("key");
        const google::protobuf::FieldDescriptor* value_field =
            entry_descriptor->FindFieldByName("value");
        if (key_field == nullptr || value_field == nullptr) {
            return;
        }
        if (entry_reflection->GetString(entry, key_field) != needed_attrkey) {
            continue;
        }

        // **2**: read AttrValue.list.i from the matching entry.
        const google::protobuf::Message& attr_value =
            entry_reflection->GetMessage(entry, value_field);
        const google::protobuf::FieldDescriptor* list_field =
            attr_value.GetDescriptor()->FindFieldByName("list");
        if (list_field == nullptr) {
            return;
        }
        const google::protobuf::Message& list_value =
            attr_value.GetReflection()->GetMessage(attr_value, list_field);
        const google::protobuf::Reflection* list_reflection = list_value.GetReflection();
        const google::protobuf::FieldDescriptor* int_field =
            list_value.GetDescriptor()->FindFieldByName("i");
        if (int_field == nullptr) {
            return;
        }

        // Indices 1 and 2 are read below. Without this guard, a shorter or
        // unset list makes GetRepeatedInt64 index out of range.
        if (list_reflection->FieldSize(list_value, int_field) < 3) {
            cout << "!!warning! strides attr has fewer than 3 values, ignored !!" << endl;
            return;
        }

        // TODO: assumes the two spatial strides are equal; handle differing
        // values later.
        const auto stride_h = list_reflection->GetRepeatedInt64(list_value, int_field, 1);
        const auto stride_w = list_reflection->GetRepeatedInt64(list_value, int_field, 2);
        returnstruct.intvalue = stride_h;
        if (stride_h != stride_w) {
            cout << "!!warning! the width and hight strides is not equal !! now only use one !!" << endl;
        }

        // Map keys are unique, so no further entry can match.
        return;
    }
}