libtorch学习笔记(1)

常用函数

1. print_tensor_data 打印tensor的值

template <typename T>
void print_tensor_data(torch::Tensor tsor,std::string name, int64_t start, int64_t end)
{
    std::cout<<name<<".shape: ";
    tsor.print();
    std::cout<<"the outdata of "<<name<<" is: "<<std::endl;
    tsor = tsor.reshape({-1});
    int64_t len = tsor.sizes()[0];
    if((end-start)> len)
    {
        end = len + start;
    }
    assert((end-start)>0 && "start or end is not right!");
    assert((end-start)<=len && "start or end is not right!");
    assert(start<len && "start is not right!");
    assert(end<=len && "end is not right!");

    for(int i=start;i<end;i++)
    {
        T s = tsor[i].item<T>();
        std::cout<<s<<" ";
    }
    std::cout<<std::endl<<std::endl;
}

2. 读取txt生成tensor

template <typename T>
std::vector<T> InputData_To_Vector(const std::string &path)
{
    std::vector<T> p;
    std::ifstream infile(path);
    assert(infile.is_open() && "Unable to open txt file. please check if the .txt file path is right!");
    T number;
    std::string s;
	while (getline(infile, s)) {
		std::istringstream is(s);
		T d;
		while (!is.eof()) {
			is >> d;
			p.push_back(d);
        }
        s.clear();
    }
    infile.close();
    return p;
}
std::vector data = InputData_To_Vector<float>("data.txt");
auto ten = torch::tensor(data).reshape({-1}).toType(torch::kFloat32);

3. 保存tensor到txt

template <typename T>
void save_tensor_2txt(torch::Tensor tsor, const std::string &path)
{
    std::ofstream outfile;
    outfile.open(path);
    assert(outfile && "failed to open the file!");    
    tsor = tsor.reshape({-1});
    for(int i=0;i<tsor.sizes()[0];i++)
    {
        T s = tsor[i].item<T>();
        outfile<<s<<"\n"; 
    }
    outfile.close();
}

4. python保存权重文件wts

    if save_wts:
        f = open('model.wts', 'w')
        f.write('{}\n'.format(len(my_model.state_dict().keys())))
        for k, v in my_model.state_dict().items():
            vr = v.reshape(-1).cpu().numpy()
            f.write('{} {}'.format(k, len(vr)))
            print(k, len(vr))
            for vv in vr:
                f.write(' ')
                f.write(struct.pack('>f', float(vv)).hex())
            f.write('\n')
        return

5. load权重文件wts

static std::map<std::string, std::vector<float>> loadWeights(const std::string &file) {
    std::cout << "Loading weights: " << file << std::endl;
    std::map<std::string, std::vector<float>> weightMap;

    // Open weights file
    std::ifstream input(file);
    assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!");

    // Read number of weight blobs
    int count;
    input >> count;
    assert(count > 0 && "Invalid weight map file.");
    std::vector<float>data;
    while (count--)
    {
        int size;
        // Read name and type of blob
        std::string name;
        input >> name >> std::dec >> size;
        unsigned int value;
        float value_f;
        for (int64_t x = 0, y = size; x < y; ++x)
        {
            input >> std::hex >>value;

            value_f = reinterpret_cast<float&>(value);
            data.push_back(value_f);
        }
        weightMap[name] = data;
        data.clear();
    }
    input.close();
    return weightMap;
}

6. GetFileNames 获取文件夹内的所有文件

static void GetFileNames(string path,vector<string>& filenames)
{
    DIR *pDir;
    struct dirent* ptr;
    if(!(pDir = opendir(path.c_str()))){
        cout<<"Folder doesn't Exist!"<<endl;
        return;
    }
    while((ptr = readdir(pDir))!=0) {
        if (strcmp(ptr->d_name, ".") != 0 && strcmp(ptr->d_name, "..") != 0){
            filenames.push_back(path + "/" + ptr->d_name);
    }
    }
    closedir(pDir);
}

7. 读取bin点云

static bool readBIN(const char *path,std::vector<float> &v, int32_t max)
{
    v.clear();
	std::ifstream infile(path, std::ios::in | std::ios::binary);
    v.resize(4*max);
    infile.read((char *)&v.front(), v.size()*sizeof(float));
	return true;
}
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

hanqu3456

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值