template<typenameT>voidprint_tensor_data(torch::Tensor tsor,std::string name,int64_t start,int64_t end){
std::cout<<name<<".shape: ";
tsor.print();
std::cout<<"the outdata of "<<name<<" is: "<<std::endl;
tsor = tsor.reshape({-1});int64_t len = tsor.sizes()[0];if((end-start)> len){
end = len + start;}assert((end-start)>0&&"start or end is not right!");assert((end-start)<=len &&"start or end is not right!");assert(start<len &&"start is not right!");assert(end<=len &&"end is not right!");for(int i=start;i<end;i++){
T s = tsor[i].item<T>();
std::cout<<s<<" ";}
std::cout<<std::endl<<std::endl;}
2. 读取txt生成tensor
/// Reads every whitespace-separated value from the text file at `path`, in file order.
///
/// Each line is parsed independently. Blank lines and trailing whitespace contribute
/// nothing, and a token that fails to parse as T ends parsing of that line only.
/// @tparam T    element type, parsed with operator>> (e.g. float).
/// @param path  path to the .txt file.
/// @return      all parsed values, flattened across lines.
template <typename T>
std::vector<T> InputData_To_Vector(const std::string &path) {
    std::vector<T> p;
    std::ifstream infile(path);
    assert(infile.is_open() && "Unable to open txt file. please check if the .txt file path is right!");

    std::string line;
    while (std::getline(infile, line)) {
        std::istringstream is(line);
        T value{};
        // Test the extraction itself rather than eof(). That way a blank line or
        // trailing whitespace adds no spurious value, and a bad token cannot
        // spin forever with failbit set.
        while (is >> value) {
            p.push_back(value);
        }
    }
    return p;  // infile is closed by its destructor
}
// Usage example: load data.txt and wrap the values as a flat float32 tensor.
std::vector<float> data = InputData_To_Vector<float>("data.txt");
auto ten = torch::tensor(data)
               .reshape({-1})
               .toType(torch::kFloat32);
3. 保存tensor到txt
template<typenameT>voidsave_tensor_2txt(torch::Tensor tsor,const std::string &path){
std::ofstream outfile;
outfile.open(path);assert(outfile &&"failed to open the file!");
tsor = tsor.reshape({-1});for(int i=0;i<tsor.sizes()[0];i++){
T s = tsor[i].item<T>();
outfile<<s<<"\n";}
outfile.close();}
4. python保存权重文件wts
if save_wts:
f =open('model.wts','w')
f.write('{}\n'.format(len(my_model.state_dict().keys())))for k, v in my_model.state_dict().items():
vr = v.reshape(-1).cpu().numpy()
f.write('{} {}'.format(k,len(vr)))print(k,len(vr))for vv in vr:
f.write(' ')
f.write(struct.pack('>f',float(vv)).hex())
f.write('\n')return
5. C++ 加载权重文件 wts
#include <cstdint>
#include <cstring>
#include <utility>

/// Loads a .wts weight file into a map from blob name to its float values.
///
/// Format read here (matching this file's Python model.wts writer):
///   - the first token is the number of blobs;
///   - each blob is then `name size`, followed by `size` tokens;
///   - each token is the hex spelling of a float32 bit pattern (e.g. 3f800000 == 1.0f).
/// @param file  path to the .wts file.
/// @return      name -> values. Malformed input trips an assert; with asserts compiled
///              out, parsing stops at the first malformed blob instead of storing garbage.
static std::map<std::string, std::vector<float>> loadWeights(const std::string &file) {
    static_assert(sizeof(float) == sizeof(std::uint32_t), "float must be 32-bit");

    std::cout << "Loading weights: " << file << std::endl;
    std::map<std::string, std::vector<float>> weightMap;

    // Open weights file
    std::ifstream input(file);
    assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!");

    // Read number of weight blobs. It is initialized, so a failed read leaves 0 and
    // trips the assert instead of exposing an indeterminate value.
    int count = 0;
    input >> count;
    assert(count > 0 && "Invalid weight map file.");

    for (int blob = 0; blob < count; ++blob) {
        // Read the blob's name and element count. std::dec undoes the std::hex
        // left on the stream by the previous blob's values.
        std::string name;
        int size = 0;
        input >> name >> std::dec >> size;

        std::vector<float> values;
        for (int i = 0; input && i < size; ++i) {
            std::uint32_t bits = 0;
            if (!(input >> std::hex >> bits)) {
                break;
            }
            // Copy the bit pattern into a float; memcpy avoids the strict-aliasing
            // UB of reinterpret_cast<float&> on an integer object.
            float value = 0.0f;
            std::memcpy(&value, &bits, sizeof value);
            values.push_back(value);
        }

        // A failed read means the file is truncated or malformed.
        assert(input && size >= 0 && "Invalid weight map file.");
        if (!input || size < 0) {
            break;
        }
        weightMap[name] = std::move(values);
    }
    return weightMap;  // input is closed by its destructor
}