如何使用C#读取safetensors扩展名的大模型文件

目录

safetensors是一种十分常见的大模型权重文件。这种模型文件最初由Hugging Face提出,目前广泛用于各类深度学习场景存储模型的权重。在Python环境下,有相应的包可以直接读取safetensors文件的权重内容,用户无需过多关注该模型文件的结构,几乎可以只靠一行代码实现。但是在C#环境下,还没有广泛使用/通用的safetensors文件的读取工具,这给C#开发者造成了不小的困扰。一些开发者会将safetensors文件转换为onnx再开发,但是这样会需要一步额外转化,而且转化时也没脱离python开发环境,所需要的依赖并不少。

为了解决C#下使用Safetensors文件的问题,故开发了相关功能。本文介绍了如何使用C#直接读取safetensors文件的内容及其各项的权重,以方便在C#环境下使用各种不同的深度学习框架,例如Tochsharp、GGMLSharp等,经过简单的处理就可以直接加载Safetensors文件权重。

Safetensors文件的结构及读取思路

Safetensors文件可以认为是一种binary文件,在C#下可以采用流的方式进行读取。

  1. Safetensors文件的结构可以大致分为头部长度+头部+权重内容。其中头部长度占用8个byte,可以转化成一个int64,以表示头部的总长度。
  2. 头部本身是json结构的,可以使用相关的库进行读取。其中每一个元素都包含了tensor的名称、类型、偏移量、形状;
  3. json结构之后是tensor数据的存储部分,需要借助前面的定义来读取;
  4. 可以建立一个新的类,用于存储读取到的tensor的信息;
  5. 为了能够节省内存/显存,提高tensors结构体的读取速度,可以先只读取tensors的结构,在使用tensor的数据时,才进行读取其值;

读取方法

使用C#对Safetensors文件读取,为了减少内存的使用,提高读取速度,建议使用流的方式。

读取Header的长度

Safetensors文件的开头8个字节标识了Header部分的长度,因此可以直接读取,按int64类型转化成整数,这一部分就是这整个Header的长度。

byte[] headerBlock = new byte[8];
stream.Read(headerBlock, 0, 8);
long headerSize = BitConverter.ToInt64(headerBlock, 0);

读取Header的内容

Header的内容可以看作是一个json文件,里面有名称、数据类型、在文件中的偏移量、形状等信息。建立一个新的类来进行读取和存储。

// Read the header, header file is a json file
byte[] headerBytes = new byte[headerSize];
stream.Read(headerBytes, 0, (int)headerSize);

string header = Encoding.UTF8.GetString(headerBytes);
long bodyPosition = stream.Position;
JToken token = JToken.Parse(header);

List<Tensor> tensors = new List<Tensor>();
foreach (var sub in token.ToObject<Dictionary<string, JToken>>())
{
	Dictionary<string, JToken> value = sub.Value.ToObject<Dictionary<string, JToken>>();
	value.TryGetValue("data_offsets", out JToken offsets);
	value.TryGetValue("dtype", out JToken dtype);
	value.TryGetValue("shape", out JToken shape);

	ulong[] offsetArray = offsets?.ToObject<ulong[]>();
	if (null == offsetArray)
	{
		continue;
	}
	long[] shapeArray = shape.ToObject<long[]>();
	if (shapeArray.Length < 1)
	{
		shapeArray = new long[] { 1 };
	}
	GGmlType ggml_type = GGmlType.GGML_TYPE_F32;
	switch (dtype.ToString())
	{
		case "I8": ggml_type = GGmlType.GGML_TYPE_I8; break;
		case "I16": ggml_type = GGmlType.GGML_TYPE_I16; break;
		case "I32": ggml_type = GGmlType.GGML_TYPE_I32; break;
		case "I64": ggml_type = GGmlType.GGML_TYPE_I64; break;
		case "BF16": ggml_type = GGmlType.GGML_TYPE_BF16; break;
		case "F16": ggml_type = GGmlType.GGML_TYPE_F16; break;
		case "F32": ggml_type = GGmlType.GGML_TYPE_F32; break;
		case "F64": ggml_type = GGmlType.GGML_TYPE_F64; break;
		case "U8":
		case "U16":
		case "U32":
		case "U64":
		case "BOOL":
		case "F8_E4M3":
		case "F8_E5M2": break;
	}

	Tensor tensor = new Tensor
	{
		Name = sub.Key,
		Type = ggml_type,
		Shape = shapeArray.ToList(),
		Offset = offsetArray.ToList(),
		FileName = inputFileName,
		BodyPosition = bodyPosition
	};

	tensors.Add(tensor);
}

其中Tensor类的定义如下:

public class Tensor
{
	public string Name { get; set; }
	public Structs.GGmlType Type { get; set; } = Structs.GGmlType.GGML_TYPE_F16;
	public List<long> Shape { get; set; } = new List<long>();
	public List<ulong> Stride { get; set; } = new List<ulong>();
	public string DataNameInZipFile { get; set; }
	public string FileName { get; set; }
	public List<ulong> Offset { get; set; } = new List<ulong>();
	public long BodyPosition { get; set; }

}

因为该代码最初是给C#使用ggml而写,所以tensor的类型使用了ggml中的精度类型,如果有需要可以根据自己的平台修改。

读取tensor的权重值

当获取到tensor的结构后就可以读这一部分了,按照流的方式读取,读取时全部按byte读取。tensor在声明时标识了自己的类型,这会在各个平台计算时自己转化。

private byte[] ReadByteFromFile(string inputFileName, long bodyPosition, long offset, int size)
{
	using (FileStream stream = File.OpenRead(inputFileName))
	{
		stream.Seek(bodyPosition + offset, SeekOrigin.Begin);
		byte[] dest = new byte[size];
		stream.Read(dest, 0, size);
		return dest;
	}
}

public byte[] ReadByteFromFile(Tensor tensor)
{
	string inputFileName = tensor.FileName;
	long bodyPosition = tensor.BodyPosition;
	ulong offset = tensor.Offset[0];
	int size = (int)(tensor.Offset[1] - tensor.Offset[0]);
	return ReadByteFromFile(inputFileName, bodyPosition, (long)offset, size);
}

总结

C#读取Safetensors文件并不算困难。只是因为使用C#搞深度学习的人并不多,相关功能实现并不普及。撰写本文是希望能够帮助更多喜欢使用C#开发深度学习项目的爱好者更容易实现自己的项目。
该项目的完整代码可以从C#读取safetensors文件方法下载。

该模块来自我正在开发的GGMLSharp项目,如果喜欢该项目,请在GitHub上送我一颗小星星。

  • 17
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值