从零构建深度学习推理框架-2 从CSV文件初始化Tensor

CSV格式

概念

CSV(逗号分隔值)文件是一种特殊的文件类型,可在 Excel 中创建或编辑。CSV文件采用逗号分隔的形式来存储文本和数字信息,总体来说,这种形式的文件格式具有扩展性好,移植性强的特点。

目前许多主流程序采用CSV文件作为数据导入导出的中间格式,例如MySQL数据库可以从CSV文件中导入数据,GMail联系人可以导出到CSV文件,然后将其导入到Outlook中。

pytorch和KuiperInfer中间数据对比文件,用于读取Pytorch复杂,且多变维度的输出,来对比我们的推理结果。

pytorch输出(被定位到csv文件)中读取,KuiperInfer读取,然后再对比.

过程是这样的:

pytorch读取权重文件--》infer--》numpy--》再保存到本地,格式为csv -->读取并初始化 tensor

kuiper 读取权重文件-->保存为tensor -->

绿色背景就是我们今天要去做的。

代码:

其实这个也不会很难:

std::shared_ptr<Tensor<float >> CSVDataLoader::LoadData(const std::string &file_path, char split_char) {
  CHECK(!file_path.empty()) << "File path is empty!";
  std::ifstream in(file_path);
  CHECK(in.is_open() && in.good()) << "File open failed! " << file_path;

  std::string line_str;
  std::stringstream line_stream;

  const auto &[rows, cols] = CSVDataLoader::GetMatrixSize(in, split_char);
  std::shared_ptr<Tensor<float>> input_tensor = std::make_shared<Tensor<float>>(1, rows, cols);
  arma::fmat &data = input_tensor->at(0);

  size_t row = 0;
  while (in.good()) {
    std::getline(in, line_str);
    if (line_str.empty()) {
      break;
    }

    std::string token;
    line_stream.clear();
    line_stream.str(line_str);

    size_t col = 0;
    while (line_stream.good()) {
      std::getline(line_stream, token, split_char);
      try {
        data.at(row, col) = std::stof(token);
      }
      catch (std::exception &e) {
        LOG(ERROR) << "Parse CSV File meet error: " << e.what();
        continue;
      }
      col += 1;
      CHECK(col <= cols) << "There are excessive elements on the column";
    }

    row += 1;
    CHECK(row <= rows) << "There are excessive elements on the row";
  }
  return input_tensor;
}

核心区段:

  const auto &[rows, cols] = CSVDataLoader::GetMatrixSize(in, split_char);
  std::shared_ptr<Tensor<float>> input_tensor = std::make_shared<Tensor<float>>(1, rows, cols);
  arma::fmat &data = input_tensor->at(0);

确认rows和cols,创建空tensor

while (line_stream.good()) {
      std::getline(line_stream, token, split_char);
      try {
        data.at(row, col) = std::stof(token);
      }
      catch (std::exception &e) {
        LOG(ERROR) << "Parse CSV File meet error: " << e.what();
        continue;
      }
      col += 1;
      CHECK(col <= cols) << "There are excessive elements on the column";
    }

    row += 1;
    CHECK(row <= rows) << "There are excessive elements on the row";
  }
  return input_tensor;

迭代,读进来赋值,返回Tensor

这里唯一需要注意的就是有表头的情况:

try {
        //todo 补充
        // 能够读取到第一行的csv列名,并存放在headers中
        // 能够读取到第二行之后的csv数据,并相应放置在data变量的row,col位置中
        
        if (row == 0)
        {
                  headers.push_back(token);
        }
        else{
        data.at(row-1, col) = std::stof(token);
        }
      }

在try中要先将表头传到header里,再将其余的数组读到data中,但记住data中的row要-1,因为要扣除表头的一行。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值