在开发中解析 NPY 文件并转换为 Tensor 的实践
在完善工作台的抠图功能时,我们需要从服务器获取一个 NPY 文件,并将其数据转换为 Tensor,以便进行后续的模型推理。在实现这个功能的过程中,遇到了一些BUG。经过研究,终于找到了一个有效的解决方案。本文将详细介绍我们是如何一步步解析 NPY 文件并将其转换为 Tensor 的。
问题描述
在机器学习和深度学习领域,NPY 文件是一种常见的用于存储多维数组的数据格式。我们需要在前端从服务器获取一个 NPY 文件,并将其内容转换为 Tensor,以便在浏览器中使用 ONNX Runtime Web 进行模型推理。然而,直接解析 NPY 文件并不是一件简单的事情,因为它包含了文件头和数据部分,文件头中存储了数据的元信息。
解决方案
经过研究,我们决定手动解析 NPY 文件的文件头,并提取其中的数据部分。以下是我们实现这一功能的详细步骤:
- 获取 NPY 文件的二进制数据:通过
fetch
请求从服务器获取 NPY 文件,并将其转换为ArrayBuffer
。 - 解析 NPY 文件头部:从二进制数据中提取文件头部信息,包括magic_number、版本号和头部长度等。
- 提取数据部分:根据头部信息确定数据部分的起始位置,并将其转换为
Float32Array
。 - 创建 Tensor:使用提取的数据创建一个 Tensor,以便进行后续的模型推理。
详细实现
下面是完整的代码实现:
segRequest.then(async (segResponse) => {
const arrayBuffer = await segResponse.arrayBuffer();
const uint8Array = new Uint8Array(arrayBuffer);
// 解析 npy 文件头部
const magic = String.fromCharCode.apply(
null,
Array.from(uint8Array.subarray(0, 6))
);
if (magic !== "\x93NUMPY") {
throw new Error("Invalid npy file");
}
const headerLength = uint8Array[8] + uint8Array[9] * 256;
const headerStr = String.fromCharCode.apply(
null,
Array.from(uint8Array.subarray(10, 10 + headerLength))
);
// 提取数据部分
const dataOffset = 10 + headerLength;
const data = new Float32Array(arrayBuffer, dataOffset);
const lowResTensor = new Tensor("float32", data, [1, 256, 64, 64]);
handleSegModelResults({
tensor: lowResTensor,
});
return data;
});
代码解析
-
获取 NPY 文件的二进制数据:
const arrayBuffer = await segResponse.arrayBuffer(); const uint8Array = new Uint8Array(arrayBuffer);
-
解析 NPY 文件头部:
-
首先,我们从文件的前 6 个字节中提取magic number,并验证文件是否为有效的 NPY 文件。
const magic = String.fromCharCode.apply( null, Array.from(uint8Array.subarray(0, 6)) ); if (magic !== "\x93NUMPY") { throw new Error("Invalid npy file"); }
-
然后,我们读取头部长度,并从文件中提取头部字符串。
const headerLength = uint8Array[8] + uint8Array[9] * 256; const headerStr = String.fromCharCode.apply( null, Array.from(uint8Array.subarray(10, 10 + headerLength)) );
-
-
提取数据部分:
-
根据头部长度确定数据部分的起始位置,并将其转换为
Float32Array
const dataOffset = 10 + headerLength; const data = new Float32Array(arrayBuffer, dataOffset);
-
-
创建 Tensor:
-
使用提取的数据创建一个 Tensor,并传递给处理函数。
const lowResTensor = new Tensor("float32", data, [1, 256, 64, 64]); handleSegModelResults({ tensor: lowResTensor, });
-
总结
通过手动解析文件头部并提取数据部分,我们成功地实现了解析 NPY 文件并将其转换为 Tensor 。希望这篇博客能对遇到类似问题的开发者有所帮助。