为方便大家理解加载“.wts”权重文件的过程,本文通过示例对加载的过程进行详细解读,包括如何读取,以什么形式读取,读取后数据是什么形式等。
此处使用的是“.pt”转“.wts”,再转“.engine”,进行tensorrt加速过程中的“.wts”文件读取的过程。
加载权重部分代码
std::map<std::string, nvinfer1::Weights> loadWeights(const std::string file){
std::cout << "Loading weights: " << file << std::endl;
std::map<std::string, nvinfer1::Weights> WeightMap;
//file文件,第一行是总层个数,第二行到最后是相应的层名称和权重。每一行先是名称,然后该行数量,最后是数值
std::ifstream input(file);
assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!");
//定义一个 int32_t类型数值
int32_t count;
//从input中读取一个int32_t的数值给count,是层的个数,即权重的个数
input>>count ;
assert(count > 0 && "Invalid weight map file.");
while(count--){
//定义wt,包括里边包含的内容,从下面可知是精度类型、权重值和权重个数
nvinfer1::Weights wt{nvinfer1::DataType::kFLOAT, nullptr, 0};
uint32_t size;
std::string name;
//十进制
input >> name >> std::dec >> size;
wt.type = nvinfer1::DataType::kFLOAT;
//reinterpret_cast<uint32_t*> 这是一个类型转换。它将malloc返回的void*类型的指针转换为uint32_t*类型的指针。
//因为C++不允许直接将一个类型的指针赋值给另一个类型,除非进行显式类型转换。
//uint32_t* val; -先定义一个指向uint32_t类型的指针变量val。
uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
for(uint32_t x = 0, y = size; x < y; x++){
//输出格式为16进制
input >> std::hex >> val[x];
}
//赋值和保存
wt.values = val;
wt.count = size;
WeightMap[name] = wt;
}
return WeightMap;
}
为方便理解其中的内容,进行摘选,运用示例进行详解。
#include <fstream>
#include <iostream>
#include <string>
int main() {
std::ifstream input("example.txt"); // 替换为你的文件路径
if (!input.is_open()) {
std::cerr << "Unable to open file!" << std::endl;
return 1; // 返回非零错误代码
}
//定义一个 int32_t类型数值
int32_t count;
//从input中读取一个int32_t的数值给count,是层的个数,即权重的个数
input >> count;
std::cout << "count: " << count<< std::endl;
while (count--) {
uint32_t size;
std::string name;
//十进制
input >> name >> std::dec >> size;
std::cout << "name: " << name << std::endl;
std::cout << "size: " << size << std::endl;
//reinterpret_cast<uint32_t*> 这是一个类型转换。它将malloc返回的void*类型的指针转换为uint32_t*类型的指针。
//因为C++不允许直接将一个类型的指针赋值给另一个类型,除非进行显式类型转换。
//uint32_t* val; -先定义一个指向uint32_t类型的指针变量val。
uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
for (uint32_t x = 0, y = size; x < y; x++) {
//输出格式为16进制
input >> std::hex >> val[x];
std::cout <<"val: " << val << std::endl;
}
//std::cout<<"name: " <<name<<" " << "val: " << val << std::endl;
}
input.close();
return 0; // 正常退出
}
“example.txt”文件中的内容。
运行程序,部分输出结果。