TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.h, caffeWeightFactory.cpp Source Code Study (Part 2)

Preface

Continuing from the previous post (TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.h, caffeWeightFactory.cpp Source Code Study (Part 1)), this post covers the remaining BlobProto-related functions, the functions that generate random weights, and the functions used to convert weight data types.

TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.cpp

/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "caffeMacros.h"
#include "caffeWeightFactory.h"
#include "half.h"

using namespace nvinfer1;
using namespace nvcaffeparser1;

//...

//Get the number of weight blobs (e.g. weights and bias) stored for a given layer of the network
int CaffeWeightFactory::getBlobsSize(const std::string& layerName)
{
    for (int i = 0, n = mMsg.layer_size(); i < n; ++i)
    {
        if (mMsg.layer(i).name() == layerName)
        {
            return mMsg.layer(i).blobs_size();
        }
    }
    return 0;
}

//Return the index-th blob of the layer named layerName in mMsg
const trtcaffe::BlobProto* CaffeWeightFactory::getBlob(const std::string& layerName, int index)
{
    if (mMsg.layer_size() > 0)
    {
        for (int i = 0, n = mMsg.layer_size(); i < n; i++)
        {
            if (mMsg.layer(i).name() == layerName && index < mMsg.layer(i).blobs_size())
            {
                return &mMsg.layer(i).blobs(index);
            }
        }
    }
    else
    {
        //Unlike the branch above, this one reads the deprecated V1 field `layers` (mMsg.layers) rather than `layer`,
        //so it acts as a fallback for caffemodels saved in the old Caffe format when layer_size() is 0
        for (int i = 0, n = mMsg.layers_size(); i < n; i++)
        {
            if (mMsg.layers(i).name() == layerName && index < mMsg.layers(i).blobs_size())
            {
                return &mMsg.layers(i).blobs(index);
            }
        }
    }

    return nullptr;
}
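// ---------------------------------------------------------------------------
// Illustration (not part of caffeWeightFactory.cpp): a minimal sketch that
// builds a tiny trtcaffe::NetParameter and queries it the same way
// getBlobsSize and getBlob do. It assumes this snippet sits in the same
// translation unit (so the protobuf-generated trtcaffe types and <iostream>
// are already available); the layer name "conv1" and the blob contents are
// made up for illustration.
// ---------------------------------------------------------------------------
static void blobLookupDemo()
{
    trtcaffe::NetParameter msg;

    // A hypothetical convolution layer carrying two blobs: kernel weights and bias
    trtcaffe::LayerParameter* conv = msg.add_layer();
    conv->set_name("conv1");
    conv->add_blobs()->add_data(0.5f); // blob 0: a (tiny) weight blob
    conv->add_blobs()->add_data(0.1f); // blob 1: a (tiny) bias blob

    // Same counting logic as getBlobsSize("conv1"), which would return 2 here
    for (int i = 0, n = msg.layer_size(); i < n; ++i)
    {
        if (msg.layer(i).name() == "conv1")
        {
            std::cout << "blobs: " << msg.layer(i).blobs_size() << std::endl;
        }
    }

    // Same lookup logic as getBlob("conv1", 1): fetch the bias blob
    const trtcaffe::BlobProto& bias = msg.layer(0).blobs(1);
    std::cout << "bias[0] = " << bias.data(0) << std::endl;
}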

//Return a dummy (empty) Weights object
Weights CaffeWeightFactory::getNullWeights()
{
    /*
    class Weights
    Defined in include/NvInferRuntime.h:
    class Weights
    {
    public:
        DataType type;      //!< The type of the weights.
        const void* values; //!< The weight values, in a contiguous array.
        int64_t count;      //!< The number of weights in the array.
    };
    Used to represent the weights of each layer in a network.
    */
    return Weights{mDataType, nullptr, 0};
}
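// ---------------------------------------------------------------------------
// Illustration (not part of caffeWeightFactory.cpp): a Weights object is just
// a typed, non-owning view over a contiguous array; getNullWeights() returns
// the degenerate case of such a view. The 2x2 kernel below is hypothetical.
// ---------------------------------------------------------------------------
static const float kDemoKernel[4] = {0.1f, 0.2f, 0.3f, 0.4f};
static const Weights kDemoWrapped{DataType::kFLOAT, kDemoKernel, 4}; // wraps an existing buffer
static const Weights kDemoEmpty{DataType::kFLOAT, nullptr, 0};       // same shape as getNullWeights()'s result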

/*
Allocate a buffer of elems elements of type getDataType(), build a new Weights object from it, and return it.
The buffer is randomly initialized with a uniform real distribution.
*/
Weights CaffeWeightFactory::allocateWeights(int64_t elems, std::uniform_real_distribution<float> distribution)
{
    void* data = malloc(elems * getDataTypeSize());

    switch (getDataType())
    {
    case DataType::kFLOAT:
        for (int64_t i = 0; i < elems; ++i)
        {
            ((float*) data)[i] = distribution(generator);
        }
        break;
    case DataType::kHALF:
        for (int64_t i = 0; i < elems; ++i)
        {
            /*
            float16
            Defined in TensorRT/parsers/common/half.h:
            typedef half_float::half float16;
            */
            //Isn't there a risk of underflow or overflow when forcing a float into a float16?
            ((float16*) data)[i] = (float16)(distribution(generator));
        }
        break;
    default:
        break;
    }

    mTmpAllocs.push_back(data);
    return Weights{getDataType(), data, elems};
}

/*
Allocate a buffer of elems elements of type getDataType(), build a new Weights object from it, and return it.
The buffer is randomly initialized with a normal distribution.
*/
Weights CaffeWeightFactory::allocateWeights(int64_t elems, std::normal_distribution<float> distribution)
{
    void* data = malloc(elems * getDataTypeSize());

    switch (getDataType())
    {
    case DataType::kFLOAT:
        for (int64_t i = 0; i < elems; ++i)
        {
            ((float*) data)[i] = distribution(generator);
        }
        break;
    case DataType::kHALF:
        for (int64_t i = 0; i < elems; ++i)
        {
            ((float16*) data)[i] = (float16)(distribution(generator));
        }
        break;
    default:
        break;
    }

    mTmpAllocs.push_back(data);
    return Weights{getDataType(), data, elems};
}
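// ---------------------------------------------------------------------------
// Illustration (not part of caffeWeightFactory.cpp): the two overloads above
// differ only in the distribution type, so their shared logic could be folded
// into one template. This is only a sketch: the name allocateRandomWeights and
// its parameters are hypothetical, not the parser's actual members, and only
// kFLOAT and kHALF are considered, matching the switch above.
// ---------------------------------------------------------------------------
template <typename Distribution>
static Weights allocateRandomWeights(DataType type, int64_t elems, Distribution distribution,
    std::default_random_engine& generator, std::vector<void*>& tmpAllocs)
{
    const size_t elemSize = (type == DataType::kHALF) ? sizeof(float16) : sizeof(float);
    void* data = malloc(elems * elemSize);
    for (int64_t i = 0; i < elems; ++i)
    {
        if (type == DataType::kFLOAT)
        {
            ((float*) data)[i] = distribution(generator);
        }
        else
        {
            ((float16*) data)[i] = (float16) distribution(generator);
        }
    }
    tmpAllocs.push_back(data); // the caller keeps ownership, mirroring mTmpAllocs
    return Weights{type, data, elems};
}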

//Note that the parameter ptr is a pointer to a pointer
/*
Allocate a new buffer of type OUTPUT, store the converted contents of *ptr in it,
then set *ptr to this new buffer and return it.
If anything goes wrong along the way, set mOK to false.
*/
template <typename INPUT, typename OUTPUT>
void* convertInternal(void** ptr, int64_t count, bool* mOK)
{
    assert(ptr != nullptr);
    if (*ptr == nullptr)
    {
        return nullptr;
    }
    if (!count)
    {
        return nullptr;
    }
    auto* iPtr = static_cast<INPUT*>(*ptr);
    auto* oPtr = static_cast<OUTPUT*>(malloc(count * sizeof(OUTPUT)));
    for (int i = 0; i < count; ++i)
    {
        //Check whether iPtr[i] falls within the valid range of the target type
        if (static_cast<OUTPUT>(iPtr[i]) > std::numeric_limits<OUTPUT>::max()
            || static_cast<OUTPUT>(iPtr[i]) < std::numeric_limits<OUTPUT>::lowest())
        {
            //Shouldn't lowest and max be swapped here? As written, the message prints the range as [max, lowest].
            std::cout << "Error: Weight " << iPtr[i] << " is outside of [" << std::numeric_limits<OUTPUT>::max()
                      << ", " << std::numeric_limits<OUTPUT>::lowest() << "]." << std::endl;
            if (mOK)
            {
                (*mOK) = false;
            }
            break;
        }
        oPtr[i] = iPtr[i];
    }
    //Set the pointer that ptr points to, to oPtr
    (*ptr) = oPtr;
    return oPtr;
}
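// ---------------------------------------------------------------------------
// Illustration (not part of caffeWeightFactory.cpp, assumed to live in the
// same translation unit as convertInternal above): what happens when one of
// the weights does not fit into half precision, whose largest finite value is
// roughly 65504. The input values are made up.
// ---------------------------------------------------------------------------
static void convertInternalDemo()
{
    bool ok = true;
    float src[3] = {0.5f, -1.25f, 1.0e30f}; // the last value is far beyond half's maximum

    void* buffer = src;
    void* converted = convertInternal<float, float16>(&buffer, 3, &ok);
    // ok is now false and the "outside of [...]" message has been printed.
    // buffer has been re-seated to the newly malloc'd float16 array (== converted);
    // convertInternal never frees that allocation itself, which is why convert()
    // below records it in mTmpAllocs for later cleanup.
    (void) converted;
}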

/*
If weights is not already of type targetType,
convert its contents to targetType
and record the new weights.values (the converted buffer) in mTmpAllocs.
*/
void CaffeWeightFactory::convert(Weights& weights, DataType targetType)
{
    void* tmpAlloc{nullptr};
    //Only float <-> half conversions are supported at the moment?
    if (weights.type == DataType::kFLOAT && targetType == DataType::kHALF)
    {
        //&weights.values: a pointer to a pointer
        //convertInternal: converts the contents of weights.values to targetType and returns the new pointer
        tmpAlloc = convertInternal<float, float16>(const_cast<void**>(&weights.values), weights.count, &mOK);
        weights.type = targetType;
    }
    if (weights.type == DataType::kHALF && targetType == DataType::kFLOAT)
    {
        tmpAlloc = convertInternal<float16, float>(const_cast<void**>(&weights.values), weights.count, &mOK);
        weights.type = targetType;
    }
    if (tmpAlloc)
    {
        mTmpAllocs.push_back(tmpAlloc);
    }
}

/*
If weights is not of the mDataType type fixed at construction time,
convert its contents to that type
and record the new weights.values in mTmpAllocs.
*/
void CaffeWeightFactory::convert(Weights& weights)
{
    convert(weights, getDataType());
}

std::uniform_real_distribution and std::normal_distribution

The CaffeWeightFactory::allocateWeights overloads use std::uniform_real_distribution<float> and std::normal_distribution<float>; see C++ uniform_real_distribution and normal_distribution for details.
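For a quick standalone reminder of how these distributions are used (standard C++, independent of the parser):

#include <iostream>
#include <random>

int main()
{
    std::default_random_engine generator;
    std::uniform_real_distribution<float> uniform(-1.0f, 1.0f); // values drawn from [-1, 1)
    std::normal_distribution<float> normal(0.0f, 1.0f);         // mean 0, standard deviation 1

    for (int i = 0; i < 3; ++i)
    {
        std::cout << uniform(generator) << " " << normal(generator) << std::endl;
    }
    return 0;
}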

std::numeric_limits<T>::max() and lowest()

The convertInternal function uses std::numeric_limits<OUTPUT>::max() and std::numeric_limits<OUTPUT>::lowest(); see C++ std::numeric_limits::max(), min() and lowest() for details.
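For reference, a tiny standard C++ example of the difference between max(), min() and lowest(), which is the distinction that matters for range checks like the one in convertInternal:

#include <iostream>
#include <limits>

int main()
{
    // For floating-point types, min() is the smallest positive normalized value,
    // while lowest() is the most negative finite value, which is what a range check needs.
    std::cout << std::numeric_limits<float>::max()    << std::endl; //  ~3.40282e+38
    std::cout << std::numeric_limits<float>::min()    << std::endl; //  ~1.17549e-38
    std::cout << std::numeric_limits<float>::lowest() << std::endl; // ~-3.40282e+38
    return 0;
}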

const_cast

The CaffeWeightFactory::convert function uses const_cast; see C++ const_cast for details.
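A short sketch of the same cast pattern (hypothetical buffer, not parser code): Weights::values is declared const void*, so taking its address yields a const void**, and const_cast<void**> strips that low-level const so convertInternal can re-seat the pointer through it.

#include <cstdlib>

int main()
{
    const void* values = nullptr;                  // stands in for weights.values
    void** writable = const_cast<void**>(&values); // strip the low-level const
    *writable = malloc(16);                        // the pointer itself can now be re-seated
    free(const_cast<void*>(values));
    return 0;
}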

Reference Links

C++ uniform_real_distribution and normal_distribution

C++ std::numeric_limits::max(), min() and lowest()

C++ const_cast

TensorRT/parsers/caffe/caffeWeightFactory/caffeWeightFactory.h, caffeWeightFactory.cpp Source Code Study (Part 1)
