如何使用C#读取pickle类型的大模型文件

背景

前几天发布了如何使用C#读取safetensors扩展名的大模型文件的文章,介绍了C#加载Safetensors文件的方法。在深度学习领域,除了Safetensors类型的权重文件外,Pickle类型的权重文件也十分流行,这种类型的权重文件一般以.pt,.pth,.pkl,.bin为扩展名。Pickle类型的文件大多是Python开发环境生成的,在Python下读写十分容易,但是除Python环境下直接加载还是十分困难的。尤其在C#环境下,还没有广泛使用/通用的pickle类型权重文件的读取工具。一些开发者会像处理safetensors文件一样,将其转换为onnx再开发,但是这样会需要一步额外转化,而且转化时也没脱离python开发环境,所需要的依赖并不少。

为了解决C#下使用Pickle类型文件的问题,故开发了相关功能。本文介绍了如何使用C#直接读取Pickle类型文件的内容及其各项的权重,以方便在C#环境下使用各种不同的深度学习框架,例如Tochsharp、GGMLSharp等,经过简单的处理就可以直接加载Pickle文件权重。

Pickle文件的结构及读取思路

Pickle文件可以认为是一种压缩文件,使用压缩工具可以直接打开。
我们可以使用压缩工具先来观察一下Pickle文件的结构
Pickle文件的内容
从上图可以看到data.pkl类型的文件有权重信息的文件,version文件包含了版本信息,data文件夹保存各个Tensor数据。我们的重点就在如何解读data.pkl文件上。

  1. 读取data.pkl文件,并以二进制的方式对其进行解析;
  2. 可以建立一个新的类,用于存储读取到的tensor的信息;
  3. 读取data文件夹内的tensor的值(data),并加载进tensor中;
  4. 为了能够节省内存/显存,提高tensors结构体的读取速度,可以先只读取tensors的结构,在使用tensor的数据时,才进行读取其值;

读取方法

使用C#以压缩文件的方式对Pickle文件整体读取,为了减少内存的使用,提高读取速度,建议先读取结构,在需要使用tensor的值时再读取其数据。

以压缩文件的方式加载Pickle类型文件

可以使用.net自带的ZipArchive方法加载压缩文件。

private ZipArchive zip;
private ReadOnlyCollection<ZipArchiveEntry> entries;
zip = ZipFile.OpenRead(fileName);
entries = zip.Entries;

读取Header的内容

Header的内容可以看作是一个二进制文件,里面有名称、数据类型、在文件中的偏移量、形状等信息。建立一个新的类来进行读取和存储。
例如一份data.pkl读取后的结果如下,以

 6: X        BINUNICODE 'model.layers.24.self_attn.q_proj.weight'

为例进行解释:
6代表该文件的第6个字节,X代表该字节对应的内容,该字节内容若以ASCII码标识就是’X’这个字符(注意区分大小写),X在Pickle文件中代表接下来的内容时Binunicode(字符类型),其内容是’model.layers.24.self_attn.q_proj.weight’
再下一行
50: q BINPUT 1
代表下一个完整内容从第50个byte开始了

    0: \x80 PROTO      2
    2: }    EMPTY_DICT
    3: q    BINPUT     0
    5: (    MARK
    6: X        BINUNICODE 'model.layers.24.self_attn.q_proj.weight'
   50: q        BINPUT     1
   52: c        GLOBAL     'torch._utils _rebuild_tensor_v2'
   85: q        BINPUT     2
   87: (        MARK
   88: (            MARK
   89: X                BINUNICODE 'storage'
  101: q                BINPUT     3
  103: c                GLOBAL     'torch HalfStorage'
  122: q                BINPUT     4
  124: X                BINUNICODE '0'
  130: q                BINPUT     5
  132: X                BINUNICODE 'cpu'
  140: q                BINPUT     6
  142: J                BININT     16777216
  147: t                TUPLE      (MARK at 88)
  148: q            BINPUT     7
  150: Q            BINPERSID
  151: K            BININT1    0
  153: M            BININT2    4096
  156: M            BININT2    4096
  159: \x86         TUPLE2
  160: q            BINPUT     8
  162: M            BININT2    4096
  165: K            BININT1    1
  167: \x86         TUPLE2
  168: q            BINPUT     9
  170: \x89         NEWFALSE
  171: c            GLOBAL     'collections OrderedDict'
  196: q            BINPUT     10
  198: )            EMPTY_TUPLE
  199: R            REDUCE
  200: q            BINPUT     11
  202: t            TUPLE      (MARK at 87)
  203: q        BINPUT     12
  205: R        REDUCE
  206: q        BINPUT     13
  208: X        BINUNICODE 'model.layers.24.self_attn.k_proj.weight'
  252: q        BINPUT     14
  254: h        BINGET     2
  256: (        MARK
  257: (            MARK
  258: h                BINGET     3
  260: h                BINGET     4
  262: X                BINUNICODE '1'
  268: q                BINPUT     15
  270: h                BINGET     6
  272: J                BININT     16777216
  277: t                TUPLE      (MARK at 257)
  278: q            BINPUT     16
  280: Q            BINPERSID
  281: K            BININT1    0
  283: M            BININT2    4096
  286: M            BININT2    4096
  289: \x86         TUPLE2
  290: q            BINPUT     17
  292: M            BININT2    4096
  295: K            BININT1    1
  297: \x86         TUPLE2
  298: q            BINPUT     18
  300: \x89         NEWFALSE
  301: h            BINGET     10
  303: )            EMPTY_TUPLE
  304: R            REDUCE
  305: q            BINPUT     19
  307: t            TUPLE      (MARK at 256)
  308: q        BINPUT     20
  310: R        REDUCE
  311: q        BINPUT     21
  313: X        BINUNICODE 'model.layers.24.self_attn.v_proj.weight'

  ......
10115: r        LONG_BINPUT 685
10120: R        REDUCE
10121: r        LONG_BINPUT 686
10126: X        BINUNICODE 'lm_head.weight'
10145: r        LONG_BINPUT 687
10150: h        BINGET     2
10152: (        MARK
10153: (            MARK
10154: h                BINGET     3
10156: h                BINGET     4
10158: X                BINUNICODE '85'
10165: r                LONG_BINPUT 688
10170: h                BINGET     6
10172: J                BININT     131072000
10177: t                TUPLE      (MARK at 10153)
10178: r            LONG_BINPUT 689
10183: Q            BINPERSID
10184: K            BININT1    0
10186: M            BININT2    32000
10189: M            BININT2    4096
10192: \x86         TUPLE2
10193: r            LONG_BINPUT 690
10198: M            BININT2    4096
10201: K            BININT1    1
10203: \x86         TUPLE2
10204: r            LONG_BINPUT 691
10209: \x89         NEWFALSE
10210: h            BINGET     10
10212: )            EMPTY_TUPLE
10213: R            REDUCE
10214: r            LONG_BINPUT 692
10219: t            TUPLE      (MARK at 10152)
10220: r        LONG_BINPUT 693
10225: R        REDUCE
10226: r        LONG_BINPUT 694
10231: u        SETITEMS   (MARK at 5)
10232: .    STOP

对data.pkl文件读取可以按byte顺序进行,一直读到该文件结束。读取的代码如下:

		public List<Tensor> ReadTensorsInfoFromFile(string fileName)
		{
			List<Tensor> tensors = new List<Tensor>();

			zip = ZipFile.OpenRead(fileName);
			entries = zip.Entries;
			ZipArchiveEntry headerEntry = entries.First(e => e.Name == "data.pkl");
			byte[] headerBytes = new byte[headerEntry.Length];
			// Header is always small enough to fit in memory, so we can read it all at once
			using (Stream stream = headerEntry.Open())
			{
				stream.Read(headerBytes, 0, headerBytes.Length);
			}

			if (headerBytes[0] != 0x80 || headerBytes[1] != 0x02)
			{
				throw new ArgumentException("Not a valid pickle file");
			}

			int index = 1;
			bool finished = false;
			bool readStrides = false;
			bool binPersid = false;

			Tensor tensor = new Tensor() { FileName = fileName, Offset = { 0 } };

			int deepth = 0;

			Dictionary<int, string> BinPut = new Dictionary<int, string>();

			while (index < headerBytes.Length && !finished)
			{
				byte opcode = headerBytes[index];
				switch (opcode)
				{
					case (byte)'}':  // EMPTY_DICT     = b'}'   # push empty dict
						break;
					case (byte)']':  // EMPTY_LIST     = b']'   # push empty list
						break;
					// skip unused sections
					case (byte)'h':  // BINGET         = b'h'   #   "    "    "    "   "   "  ;   "    " 1-byte arg
						{
							int id = headerBytes[index + 1];
							BinPut.TryGetValue(id, out string precision);
							if (precision != null)
							{
								if (precision.Contains("FloatStorage"))
								{
									tensor.Type = Structs.GGmlType.GGML_TYPE_F32;
								}
								else if (precision.Contains("HalfStorage"))
								{
									tensor.Type = Structs.GGmlType.GGML_TYPE_F16;
								}
								else if (precision.Contains("BFloat16Storage"))
								{
									tensor.Type = Structs.GGmlType.GGML_TYPE_BF16;
								}
							}
							index++;
							break;
						}
					case (byte)'q':  // BINPUT         = b'q'   #   "     "    "   "   " ;   "    " 1-byte arg
						{
							index++;
							break;
						}
					case (byte)'Q':  // BINPERSID      = b'Q'   #  "       "         "  ;  "  "   "     "  stack
						binPersid = true;
						break;
					case (byte)'r':  // LONG_BINPUT    = b'r'   #   "     "    "   "   " ;   "    " 4-byte arg
						index += 4;
						break;
					case 0x95:  // FRAME            = b'\x95'  # indicate the beginning of a new frame
						index += 8;
						break;
					case 0x94:  // MEMOIZE          = b'\x94'  # store top of the stack in memo
						break;
					case (byte)'(':  // MARK           = b'('   # push special markobject on stack
						deepth++;
						break;
					case (byte)'K':  // BININT1        = b'K'   # push 1-byte unsigned int
						{
							int value = headerBytes[index + 1];
							index++;

							if (deepth > 1 && value != 0 && binPersid)
							{
								if (readStrides)
								{
									//tensor.Stride.Add((ulong)value);
									tensor.Stride.Add((ulong)value);
								}
								else
								{
									tensor.Shape.Add(value);
								}
							}
						}
						break;
					case (byte)'M':  // BININT2        = b'M'   # push 2-byte unsigned int
						{
							UInt16 value = BitConverter.ToUInt16(headerBytes, index + 1);
							index += 2;

							if (deepth > 1 && value != 0 && binPersid)
							{
								if (readStrides)
								{
									tensor.Stride.Add(value);
								}
								else
								{
									tensor.Shape.Add(value);
								}
							}

						}
						break;
					case (byte)'J':  // BININT         = b'J'   # push four-byte signed int
						{
							int value = BitConverter.ToInt32(headerBytes, index + 1);
							//int value = headerBytes[index + 4] << 24 + headerBytes[index + 3] << 16 + headerBytes[index + 2] << 8 + headerBytes[index + 1];
							index += 4;

							if (deepth > 1 && value != 0 && binPersid)
							{
								if (readStrides)
								{
									tensor.Stride.Add((ulong)value);
								}
								else
								{
									tensor.Shape.Add(value);
								}
							}
						}
						break;

					case (byte)'X':  // BINUNICODE     = b'X'   #   "     "       "  ; counted UTF-8 string argument
						{
							int length = headerBytes[index + 1];
							int start = index + 5;
							byte module = headerBytes[index + 1];
							string name = System.Text.Encoding.UTF8.GetString(headerBytes, start, length);
							index = index + 4 + length;

							if (deepth == 1)
							{
								tensor.Name = name;
							}
							else if (deepth == 3)
							{
								if ("cpu" != name && !name.Contains("cuda"))
								{
									tensor.DataNameInZipFile = name;
								}
							}
						}
						break;
					case 0x8C:  // SHORT_BINUNICODE = b'\x8c'  # push short string; UTF-8 length < 256 bytes
						{

						}
						break;
					case (byte)'c':  // GLOBAL         = b'c'   # push self.find_class(modname, name); 2 string args
						{
							int start = index + 1;
							while (headerBytes[index + 1] != (byte)'q')
							{
								index++;
							}
							int length = index - start + 1;

							string global = System.Text.Encoding.UTF8.GetString(headerBytes, start, length);

							// precision is stored in the global variable
							// next tensor will read the precision
							// so we can set the Type here

							BinPut.Add(headerBytes[index + 2], global);

							if (global.Contains("FloatStorage"))
							{
								tensor.Type = Structs.GGmlType.GGML_TYPE_F32;
							}
							else if (global.Contains("HalfStorage"))
							{
								tensor.Type = Structs.GGmlType.GGML_TYPE_F16;
							}
							else if (global.Contains("BFloat16Storage"))
							{
								tensor.Type = Structs.GGmlType.GGML_TYPE_BF16;
							}
							break;
						}
					case 0x86:  // TUPLE2         = b'\x86'  # build 2-tuple from two topmost stack items
						{
							if (binPersid)
							{
								readStrides = true;
							}
							break;
						}
					case 0x85:  // TUPLE1         = b'\x85'  # build 1-tuple from stack top
						if (binPersid)
						{
							readStrides = true;
						}
						break;
					case (byte)'t':   // TUPLE          = b't'   # build tuple from topmost stack items
						deepth--;
						if (binPersid)
						{
							readStrides = true;
						}
						break;
					case (byte)'R': // REDUCE         = b'R'   # apply callable to argtuple, both on stack
						if (deepth == 1)
						{
							if (tensor.Name.Contains("metadata"))
							{
								break;
							}

							if (string.IsNullOrEmpty(tensor.DataNameInZipFile))
							{
								tensor.DataNameInZipFile = tensors.Last().DataNameInZipFile;
								tensor.Offset = new List<ulong> { (ulong)tensor.Shape[0] * Common.GetGGmlTypeSize(tensor.Type) };
								tensor.Shape.RemoveAt(0);
								//tensor.offset = tensors.Last().
							}
							tensors.Add(tensor);

							tensor = new Tensor() { FileName = fileName, Offset = { 0 } };
							readStrides = false;
							binPersid = false;
						}
						break;
					case (byte)'.':  // STOP           = b'.'   # every pickle ends with STOP
						finished = true;
						break;
					default:
						break;
				}
				index++;
			}
			Tensor metaTensor = tensors.Find(x => x.Name.Contains("_metadata"));
			if (metaTensor != null)
			{
				tensors.Remove(metaTensor);
			}
			return tensors;
		}

其中Tensor类的定义如下:

public class Tensor
{
	public string Name { get; set; }
	public Structs.GGmlType Type { get; set; } = Structs.GGmlType.GGML_TYPE_F16;
	public List<long> Shape { get; set; } = new List<long>();
	public List<ulong> Stride { get; set; } = new List<ulong>();
	public string DataNameInZipFile { get; set; }
	public string FileName { get; set; }
	public List<ulong> Offset { get; set; } = new List<ulong>();
	public long BodyPosition { get; set; }

}

因为该代码最初是给C#使用ggml而写,所以tensor的类型使用了ggml中的精度类型,如果有需要可以根据自己的平台修改。

读取tensor的权重值

当获取到tensor的结构后就可以读这一部分了,按照流的方式读取,读取时全部按byte读取。tensor在声明时标识了自己的类型,这会在各个平台计算时自己转化。

		public byte[] ReadByteFromFile(Tensor tensor)
		{
			if (entries is null)
			{
				throw new ArgumentNullException(nameof(entries));
			}

			ZipArchiveEntry dataEntry = entries.First(e => e.Name == tensor.DataNameInZipFile);
			long i = 1;
			foreach (var ne in tensor.Shape)
			{
				i *= ne;
			}
			ulong length = Common.GetGGmlTypeSize(tensor.Type) * (ulong)i;
			byte[] data = new byte[dataEntry.Length];

			using (Stream stream = dataEntry.Open())
			{
				stream.Read(data, 0, data.Length);
			}

			//data = data.Take(new Range((int)tensor.Offset[0], (int)(tensor.Offset[0] + length))).ToArray();
			byte[] result = new byte[length];
			for (int j = 0; j < (int)length; j++)
			{
				result[j] = data[j + (int)tensor.Offset[0]];
			}
			return result;
			//return data;
		}

该方法的不足

该方法虽然能够读取一些Pickle类型的权重文件,但并不是所有。目前只支持文件中仅有权重的类型,若data.pkl文件中还包含了图的信息,目前该方法还不能正常使用。若有兴趣可以帮助我完善PickleLoader这个方法,一起为C#在深度学习领域添砖加瓦。

总结

C#读取Pickle类型的权重文件较Safetensors文件更加困难,其原因为在C#下没有很好的反序列化data.pkl文件的方式,只能按照文件定义顺序读取。目前使用C#搞深度学习的人并不多,相关功能实现并不普及。撰写本文是希望能够帮助更多喜欢使用C#开发深度学习项目的爱好者更容易实现自己的项目。
该项目的完整代码可以从C#直接读取Pickle类型权重下载。

该模块来自我正在开发的GGMLSharp项目,如果喜欢该项目,请在GitHub上送我一颗小星星。

  • 16
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值