TensorRT/parsers/caffe/caffeParser/readProto.h源碼研讀
TensorRT/parsers/caffe/caffeParser/readProto.h
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TRT_CAFFE_PARSER_READ_PROTO_H
#define TRT_CAFFE_PARSER_READ_PROTO_H
#include <fstream>
#include <limits>
#include <string>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "caffeMacros.h"
#include "trtcaffe.pb.h"
namespace nvcaffeparser1
{
// There are some challenges associated with importing caffe models. One is that
// a .caffemodel file just consists of layers and doesn't have the specs for its
// input and output blobs.
//
// So we need to read the deploy file to get the input
//從.caffemodel中讀取模型權重到net裡
bool readBinaryProto(trtcaffe::NetParameter* net, const char* file, size_t bufSize)
{
//因為是macro函數,所以結尾不加分號
//如果net或file是nullptr,就return false
CHECK_NULL_RET_VAL(net, false)
CHECK_NULL_RET_VAL(file, false)
using namespace google::protobuf::io;
std::ifstream stream(file, std::ios::in | std::ios::binary);
if (!stream)
{
//輸出錯誤訊息,並回傳false
RETURN_AND_LOG_ERROR(false, "Could not open file " + std::string(file));
}
//創建一個從C++ istream裡讀取數據的流
IstreamInputStream rawInput(&stream);
/*
IstreamInputStream為ZeroCopyInputStream的子類別
從ZeroCopyInputStream讀取數據並解碼。
*/
CodedInputStream codedInput(&rawInput);
//設定CodedInputStream物件將讀取的最大bytes數,第二個參數將被忽略
codedInput.SetTotalBytesLimit(int(bufSize), -1);
/*
從給定的input stream裡解析出protocol buffer,
並填入net這個message物件中
*/
bool ok = net->ParseFromCodedStream(&codedInput);
stream.close();
if (!ok)
{
RETURN_AND_LOG_ERROR(false, "Could not parse binary model file");
}
return ok;
}
//從deploy.prototxt中讀取模型架構到net裡
bool readTextProto(trtcaffe::NetParameter* net, const char* file)
{
//因為是macro函數,所以結尾不加分號
CHECK_NULL_RET_VAL(net, false)
CHECK_NULL_RET_VAL(file, false)
using namespace google::protobuf::io;
std::ifstream stream(file, std::ios::in);
if (!stream)
{
RETURN_AND_LOG_ERROR(false, "Could not open file " + std::string(file));
}
//創建一個從C++ istream裡讀取數據的流
IstreamInputStream input(&stream);
/*
從給定的ZeroCopyInputStream裡讀取並解析文字格式的protocol message,
存到給定的Message物件net當中
*/
bool ok = google::protobuf::TextFormat::Parse(&input, net);
stream.close();
return ok;
}
} //namespace nvcaffeparser1
#endif //TRT_CAFFE_PARSER_READ_PROTO_H
trtcaffe.pb.h
TensorRT的源碼中並沒有 `trtcaffe.pb.h` 這個檔案,那麼它是從何而來呢?詳見Protocol Buffer(proto2)及C++ API。
std::ifstream
`readBinaryProto` 及 `readTextProto` 兩個函數中都用到了 `std::ifstream`,詳見C++ ifstream。
google::protobuf::io
`readBinaryProto` 及 `readTextProto` 兩個函數中都用到了來自 `google/protobuf` 套件的函數,詳見 C++ google protobuf。