【C++】“.wts”权重文件内容读取详解

为方便大家理解加载“.wts”权重文件的过程,本文通过示例对加载的过程进行详细解读,包括如何读取,以什么形式读取,读取后数据是什么形式等。
此处使用的是“.pt”转“.wts”,再转“.engine”,进行tensorrt加速过程中的“.wts”文件读取的过程。

加载权重部分代码

std::map<std::string, nvinfer1::Weights> loadWeights(const std::string file){
    std::cout << "Loading weights: " << file << std::endl;
    std::map<std::string, nvinfer1::Weights> WeightMap;
    //file文件,第一行是总层个数,第二行到最后是相应的层名称和权重。每一行先是名称,然后该行数量,最后是数值
    std::ifstream input(file);
    assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!");
    //定义一个 int32_t类型数值
    int32_t count;
    //从input中读取一个int32_t的数值给count,是层的个数,即权重的个数
    input>>count ;
    assert(count > 0 && "Invalid weight map file.");

    while(count--){
    	//定义wt,包括里边包含的内容,从下面可知是精度类型、权重值和权重个数
        nvinfer1::Weights wt{nvinfer1::DataType::kFLOAT, nullptr, 0};
        uint32_t size;

        std::string name;
        //十进制
        input >> name >> std::dec >> size;
        wt.type = nvinfer1::DataType::kFLOAT;
        //reinterpret_cast<uint32_t*> 这是一个类型转换。它将malloc返回的void*类型的指针转换为uint32_t*类型的指针。
        //因为C++不允许直接将一个类型的指针赋值给另一个类型,除非进行显式类型转换。
        //uint32_t* val; -先定义一个指向uint32_t类型的指针变量val。
        uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
        for(uint32_t x = 0, y = size; x < y; x++){
            //输出格式为16进制
            input >> std::hex >> val[x];
        }
        //赋值和保存
        wt.values = val;
        wt.count = size;
        WeightMap[name] = wt;
    }
    return WeightMap;
}

为方便理解其中的内容,进行摘选,运用示例进行详解。

#include <fstream>
#include <iostream>
#include <string>

int main() {
    std::ifstream input("example.txt"); // 替换为你的文件路径
    if (!input.is_open()) {
        std::cerr << "Unable to open file!" << std::endl;
        return 1; // 返回非零错误代码
    }


    //定义一个 int32_t类型数值
    int32_t count;
    //从input中读取一个int32_t的数值给count,是层的个数,即权重的个数
    input >> count;
    std::cout << "count:     " << count<< std::endl;

    while (count--) {
        
        uint32_t size;

        std::string name;
        //十进制
        input >> name >> std::dec >> size;
        std::cout << "name:     " << name << std::endl;
        std::cout << "size:     " << size << std::endl;
        //reinterpret_cast<uint32_t*> 这是一个类型转换。它将malloc返回的void*类型的指针转换为uint32_t*类型的指针。
        //因为C++不允许直接将一个类型的指针赋值给另一个类型,除非进行显式类型转换。
        //uint32_t* val; -先定义一个指向uint32_t类型的指针变量val。
        uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
        for (uint32_t x = 0, y = size; x < y; x++) {
            //输出格式为16进制
            input >> std::hex >> val[x];
            std::cout <<"val:   " << val << std::endl;
        }
        //std::cout<<"name:     " <<name<<"    " << "val:   " << val << std::endl;
        
    }


    input.close();
    return 0; // 正常退出
}

“example.txt”文件中的内容。
在这里插入图片描述
运行程序,部分输出结果。

在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

木彳

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值