TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.h,caffeWeightFactory.cpp源碼研讀四

TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.h,caffeWeightFactory.cpp源碼研讀四

前言

第一篇第二篇第三篇,本篇最後將介紹用於獲取權重的函數。

TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.cpp

/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "caffeMacros.h"
#include "caffeWeightFactory.h"
#include "half.h"

using namespace nvinfer1;
using namespace nvcaffeparser1;

//...

/*
獲取layerName層與blobMsg(可能是weight或bias)相關的權重

將blobMsg裡的數據轉為trtcaffe::FLOAT型別,構造一個Weights物件後回傳,
如果中間有任何問題,那麼mOK將會被設為false,並回傳dummy的Weights物件
*/
Weights CaffeWeightFactory::getWeights(const trtcaffe::BlobProto& blobMsg, const std::string& layerName)
{
    // Always load weights into FLOAT format
    /*
    getWeights是private成員函數,可以存取private成員變數mTmpAllocs,
    而getBlobProtoData這個public成員函數則不行,
    所以這裡得透過傳參的方式來修改它
    */
    //回傳一個pair,first為轉為type型別的blobMsg裡的數據,second為其元素個數
    const auto blobProtoData = getBlobProtoData(blobMsg, trtcaffe::FLOAT, mTmpAllocs);

    if (blobProtoData.first == nullptr)
    {
        const int bits = mDataType == DataType::kFLOAT ? 32 : 16;
        std::cout << layerName << ": ERROR - " << bits << "-bit weights not found for "
                    << bits << "-bit model" << std::endl;
        mOK = false;
        return Weights{DataType::kFLOAT, nullptr, 0};
    }

    //checkForNans:回傳true代表能有效地被轉成float
    mOK &= checkForNans<float>(blobProtoData.first, int(blobProtoData.second), layerName);
    return Weights{DataType::kFLOAT, blobProtoData.first, int(blobProtoData.second)};
}

//將某層的各blob的權重轉為DataType::kFLOAT型別,並放入一個向量後回傳
std::vector<Weights> CaffeWeightFactory::getAllWeights(const std::string& layerName)
{
    std::vector<Weights> v;
    //i用於遍歷某一層的多個blob
    for (int i = 0;; i++)
    {
        //獲取layerName層的第i個blob
        auto b = getBlob(layerName, i);
        if (b == nullptr)
        {
            break;
        }
        //將b裡的數據轉為trtcaffe::FLOAT型別,構造一個Weights物件後回傳
        auto weights = getWeights(*b, layerName);
        /*
        如果weights的型別不是targetType,
		就將weights裡的內容轉換為targetType型別,
		並將weights.values記錄於mTmpAllocs中
		*/
        convert(weights, DataType::kFLOAT);
        v.push_back(weights);
    }
    return v;
}

/*
定義於TensorRT/parsers/caffe/caffeWeightFactory/weightType.h
enum class WeightType
{
    // types for convolution, deconv, fully connected
    kGENERIC = 0, // typical weights for the layer: e.g. filter (for conv) or matrix weights (for innerproduct)
    kBIAS = 1,    // bias weights

    // These enums are for BVLCCaffe, which are incompatible with nvCaffe enums below.
    // See batch_norm_layer.cpp in BLVC source of Caffe
    kMEAN = 0,
    kVARIANCE = 1,
    kMOVING_AVERAGE = 2,

    // These enums are for nvCaffe, which are incompatible with BVLCCaffe enums above
    // See batch_norm_layer.cpp in NVidia fork of Caffe
    kNVMEAN = 0,
    kNVVARIANCE = 1,
    kNVSCALE = 3,
    kNVBIAS = 4
};
*/
/*
獲取layerName層的第int(weightType)個blob(即權重或偏置量),
如果沒有,則回傳長度為0的Null weights
*/
Weights CaffeWeightFactory::operator()(const std::string& layerName, WeightType weightType)
{
    /*
    獲取layerName層的第int(weightType)個blob
    int(weightType):如果是權重則為0;如果是偏置量則為1
    */
    const trtcaffe::BlobProto* blobMsg = getBlob(layerName, int(weightType));
    if (blobMsg == nullptr)
    {
        std::cout << "Weights for layer " << layerName << " doesn't exist" << std::endl;
        RETURN_AND_LOG_ERROR(getNullWeights(), "ERROR: Attempting to access NULL weights");
        //上面都已經RETURN了,這裡的assert(0)還有用?
        assert(0);
    }
    //將blobMsg裡的數據轉為trtcaffe::FLOAT型別,構造一個Weights物件後回傳
    return getWeights(*blobMsg, layerName);
}

參考連結

TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.h,caffeWeightFactory.cpp源碼研讀一

TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.h,caffeWeightFactory.cpp源碼研讀二

TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.h,caffeWeightFactory.cpp源碼研讀三

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值