How to Read Pickle-Format Large Model Files with C#
Background
A few days ago I published an article on reading large model files with the safetensors extension in C#, which covered how to load Safetensors files. In deep learning, besides Safetensors weight files, Pickle-format weight files are also very popular; they usually carry the extensions .pt, .pth, .pkl, or .bin. Pickle files are mostly produced in Python environments, where reading and writing them is trivial, but loading them outside of Python remains difficult. In C# in particular, there is still no widely used, general-purpose reader for Pickle-format weight files. Some developers convert them to ONNX first, as they would with safetensors files, but that adds an extra conversion step, still requires a Python environment for the conversion, and pulls in quite a few dependencies.
To solve the problem of using Pickle files from C#, I implemented the functionality described here. This article shows how to read the contents of a Pickle file, and the weights of each tensor in it, directly in C#, so that after some light post-processing the weights can be loaded into the various deep learning frameworks available for C#, such as TorchSharp and GGMLSharp.
Pickle File Structure and Reading Approach
A PyTorch Pickle weight file (.pt/.pth/.bin) is really a ZIP archive, so it can be opened directly with any archive tool.
Let's first inspect the structure of a Pickle file with an archive tool.
As the screenshot above shows, the data.pkl file holds the weight metadata, the version file holds the version information, and the data folder holds the raw data of each tensor. Our focus is on how to interpret the data.pkl file.
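The same inspection can be done in code. Below is a minimal sketch using .NET's built-in ZipFile API; the PickleInspector name and the model.pt path are placeholders for illustration:

```csharp
using System;
using System.IO.Compression;

static class PickleInspector
{
    // A PyTorch .pt/.pth file is a ZIP archive; list everything it contains.
    public static string[] ListEntries(string path)
    {
        using (ZipArchive zip = ZipFile.OpenRead(path))
        {
            string[] names = new string[zip.Entries.Count];
            for (int i = 0; i < zip.Entries.Count; i++)
            {
                names[i] = zip.Entries[i].FullName;
            }
            return names;
        }
    }
}
```

On a typical checkpoint, `PickleInspector.ListEntries("model.pt")` returns entries such as archive/data.pkl, archive/version, and archive/data/0.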
- Read the data.pkl file and parse it as binary;
- Define a new class to hold the tensor information that is read;
- Read each tensor's raw values (data) from the data folder and load them into the tensor;
- To save memory/VRAM and speed up reading the tensor structures, read only the structure first, and defer reading a tensor's values until the data is actually used.
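The last point, deferring the data read, can be sketched with .NET's Lazy&lt;T&gt;. The LazyTensor name and the loader delegate here are illustrative, not part of the final code:

```csharp
using System;

// Sketch of the deferred read: keep only the metadata up front and run the
// loader delegate the first time the data is actually requested.
class LazyTensor
{
    public string Name { get; }
    public int[] Shape { get; }
    private readonly Lazy<byte[]> data;

    public LazyTensor(string name, int[] shape, Func<byte[]> loader)
    {
        Name = name;
        Shape = shape;
        data = new Lazy<byte[]>(loader); // loader is not invoked yet
    }

    public byte[] Data => data.Value; // first access triggers the read
}
```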
Reading Method
Read the whole Pickle file as an archive in C#. To reduce memory usage and speed up loading, it is best to read the structure first and fetch a tensor's data only when its values are needed.
Loading a Pickle File as an Archive
The ZipArchive class built into .NET can be used to open the archive.
private ZipArchive zip;
private ReadOnlyCollection<ZipArchiveEntry> entries;
zip = ZipFile.OpenRead(fileName);
entries = zip.Entries;
Reading the Header
The header can be treated as a binary file containing each tensor's name, data type, offset within the file, shape, and so on. Define a new class to read and store this information.
For example, one data.pkl disassembles as shown below. Take the line
6: X BINUNICODE 'model.layers.24.self_attn.q_proj.weight'
as an illustration:
6 is the byte offset of this opcode within the file; X is the opcode at that offset, i.e. the ASCII character 'X' (note that case matters), which in the pickle protocol means the following content is a BINUNICODE (a counted UTF-8 string); its value is 'model.layers.24.self_attn.q_proj.weight'.
The following line
50: q BINPUT 1
shows that the next opcode starts at byte 50.
0: \x80 PROTO 2
2: } EMPTY_DICT
3: q BINPUT 0
5: ( MARK
6: X BINUNICODE 'model.layers.24.self_attn.q_proj.weight'
50: q BINPUT 1
52: c GLOBAL 'torch._utils _rebuild_tensor_v2'
85: q BINPUT 2
87: ( MARK
88: ( MARK
89: X BINUNICODE 'storage'
101: q BINPUT 3
103: c GLOBAL 'torch HalfStorage'
122: q BINPUT 4
124: X BINUNICODE '0'
130: q BINPUT 5
132: X BINUNICODE 'cpu'
140: q BINPUT 6
142: J BININT 16777216
147: t TUPLE (MARK at 88)
148: q BINPUT 7
150: Q BINPERSID
151: K BININT1 0
153: M BININT2 4096
156: M BININT2 4096
159: \x86 TUPLE2
160: q BINPUT 8
162: M BININT2 4096
165: K BININT1 1
167: \x86 TUPLE2
168: q BINPUT 9
170: \x89 NEWFALSE
171: c GLOBAL 'collections OrderedDict'
196: q BINPUT 10
198: ) EMPTY_TUPLE
199: R REDUCE
200: q BINPUT 11
202: t TUPLE (MARK at 87)
203: q BINPUT 12
205: R REDUCE
206: q BINPUT 13
208: X BINUNICODE 'model.layers.24.self_attn.k_proj.weight'
252: q BINPUT 14
254: h BINGET 2
256: ( MARK
257: ( MARK
258: h BINGET 3
260: h BINGET 4
262: X BINUNICODE '1'
268: q BINPUT 15
270: h BINGET 6
272: J BININT 16777216
277: t TUPLE (MARK at 257)
278: q BINPUT 16
280: Q BINPERSID
281: K BININT1 0
283: M BININT2 4096
286: M BININT2 4096
289: \x86 TUPLE2
290: q BINPUT 17
292: M BININT2 4096
295: K BININT1 1
297: \x86 TUPLE2
298: q BINPUT 18
300: \x89 NEWFALSE
301: h BINGET 10
303: ) EMPTY_TUPLE
304: R REDUCE
305: q BINPUT 19
307: t TUPLE (MARK at 256)
308: q BINPUT 20
310: R REDUCE
311: q BINPUT 21
313: X BINUNICODE 'model.layers.24.self_attn.v_proj.weight'
......
10115: r LONG_BINPUT 685
10120: R REDUCE
10121: r LONG_BINPUT 686
10126: X BINUNICODE 'lm_head.weight'
10145: r LONG_BINPUT 687
10150: h BINGET 2
10152: ( MARK
10153: ( MARK
10154: h BINGET 3
10156: h BINGET 4
10158: X BINUNICODE '85'
10165: r LONG_BINPUT 688
10170: h BINGET 6
10172: J BININT 131072000
10177: t TUPLE (MARK at 10153)
10178: r LONG_BINPUT 689
10183: Q BINPERSID
10184: K BININT1 0
10186: M BININT2 32000
10189: M BININT2 4096
10192: \x86 TUPLE2
10193: r LONG_BINPUT 690
10198: M BININT2 4096
10201: K BININT1 1
10203: \x86 TUPLE2
10204: r LONG_BINPUT 691
10209: \x89 NEWFALSE
10210: h BINGET 10
10212: ) EMPTY_TUPLE
10213: R REDUCE
10214: r LONG_BINPUT 692
10219: t TUPLE (MARK at 10152)
10220: r LONG_BINPUT 693
10225: R REDUCE
10226: r LONG_BINPUT 694
10231: u SETITEMS (MARK at 5)
10232: . STOP
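As a sanity check on the dump above, a single BINUNICODE record can be decoded by hand: the opcode byte 'X' is followed by a 4-byte little-endian length and then the UTF-8 payload. A minimal sketch (the DecodeBinunicode helper is illustrative):

```csharp
using System;
using System.Text;

static class PickleOps
{
    // Decode one BINUNICODE record: 'X', 4-byte little-endian length, UTF-8 bytes.
    // Returns the string value and the offset of the next opcode.
    public static string DecodeBinunicode(byte[] buf, int index, out int next)
    {
        if (buf[index] != (byte)'X')
        {
            throw new ArgumentException("not a BINUNICODE record");
        }
        int length = BitConverter.ToInt32(buf, index + 1);
        string value = Encoding.UTF8.GetString(buf, index + 5, length);
        next = index + 5 + length;
        return value;
    }
}
```

For the record at offset 6, the name is 39 bytes long, so the next opcode starts at 6 + 5 + 39 = 50, exactly where the dump shows the following BINPUT.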
The data.pkl file can be read byte by byte, in order, until the end of the file. The reading code is as follows:
public List<Tensor> ReadTensorsInfoFromFile(string fileName)
{
List<Tensor> tensors = new List<Tensor>();
zip = ZipFile.OpenRead(fileName);
entries = zip.Entries;
ZipArchiveEntry headerEntry = entries.First(e => e.Name == "data.pkl");
byte[] headerBytes = new byte[headerEntry.Length];
// Header is always small enough to fit in memory, so we can read it all at once
using (Stream stream = headerEntry.Open())
{
    int read = 0;
    while (read < headerBytes.Length)
    {
        int n = stream.Read(headerBytes, read, headerBytes.Length - read);
        if (n == 0) break; // unexpected end of stream
        read += n;
    }
}
if (headerBytes[0] != 0x80 || headerBytes[1] != 0x02)
{
throw new ArgumentException("Not a valid pickle file");
}
int index = 1;
bool finished = false;
bool readStrides = false;
bool binPersid = false;
Tensor tensor = new Tensor() { FileName = fileName, Offset = { 0 } };
int deepth = 0;
Dictionary<int, string> BinPut = new Dictionary<int, string>();
while (index < headerBytes.Length && !finished)
{
byte opcode = headerBytes[index];
switch (opcode)
{
case (byte)'}': // EMPTY_DICT = b'}' # push empty dict
break;
case (byte)']': // EMPTY_LIST = b']' # push empty list
break;
// skip unused sections
case (byte)'h': // BINGET = b'h' # " " " " " " ; " " 1-byte arg
{
int id = headerBytes[index + 1];
BinPut.TryGetValue(id, out string precision);
if (precision != null)
{
if (precision.Contains("FloatStorage"))
{
tensor.Type = Structs.GGmlType.GGML_TYPE_F32;
}
else if (precision.Contains("HalfStorage"))
{
tensor.Type = Structs.GGmlType.GGML_TYPE_F16;
}
else if (precision.Contains("BFloat16Storage"))
{
tensor.Type = Structs.GGmlType.GGML_TYPE_BF16;
}
}
index++;
break;
}
case (byte)'q': // BINPUT = b'q' # " " " " " ; " " 1-byte arg
{
index++;
break;
}
case (byte)'Q': // BINPERSID = b'Q' # " " " ; " " " " stack
binPersid = true;
break;
case (byte)'r': // LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg
index += 4;
break;
case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame
index += 8;
break;
case 0x94: // MEMOIZE = b'\x94' # store top of the stack in memo
break;
case (byte)'(': // MARK = b'(' # push special markobject on stack
deepth++;
break;
case (byte)'K': // BININT1 = b'K' # push 1-byte unsigned int
{
int value = headerBytes[index + 1];
index++;
if (deepth > 1 && value != 0 && binPersid)
{
if (readStrides)
{
tensor.Stride.Add((ulong)value);
}
else
{
tensor.Shape.Add(value);
}
}
}
break;
case (byte)'M': // BININT2 = b'M' # push 2-byte unsigned int
{
UInt16 value = BitConverter.ToUInt16(headerBytes, index + 1);
index += 2;
if (deepth > 1 && value != 0 && binPersid)
{
if (readStrides)
{
tensor.Stride.Add(value);
}
else
{
tensor.Shape.Add(value);
}
}
}
break;
case (byte)'J': // BININT = b'J' # push four-byte signed int
{
int value = BitConverter.ToInt32(headerBytes, index + 1);
index += 4;
if (deepth > 1 && value != 0 && binPersid)
{
if (readStrides)
{
tensor.Stride.Add((ulong)value);
}
else
{
tensor.Shape.Add(value);
}
}
}
break;
case (byte)'X': // BINUNICODE = b'X' # " " " ; counted UTF-8 string argument
{
// BINUNICODE carries a 4-byte little-endian length, then the UTF-8 bytes
int length = BitConverter.ToInt32(headerBytes, index + 1);
int start = index + 5;
string name = System.Text.Encoding.UTF8.GetString(headerBytes, start, length);
index = index + 4 + length;
if (deepth == 1)
{
tensor.Name = name;
}
else if (deepth == 3)
{
if ("cpu" != name && !name.Contains("cuda"))
{
tensor.DataNameInZipFile = name;
}
}
}
break;
case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes
{
}
break;
case (byte)'c': // GLOBAL = b'c' # push self.find_class(modname, name); 2 string args
{
int start = index + 1;
while (headerBytes[index + 1] != (byte)'q')
{
index++;
}
int length = index - start + 1;
string global = System.Text.Encoding.UTF8.GetString(headerBytes, start, length);
// precision is stored in the global variable
// next tensor will read the precision
// so we can set the Type here
BinPut.Add(headerBytes[index + 2], global);
if (global.Contains("FloatStorage"))
{
tensor.Type = Structs.GGmlType.GGML_TYPE_F32;
}
else if (global.Contains("HalfStorage"))
{
tensor.Type = Structs.GGmlType.GGML_TYPE_F16;
}
else if (global.Contains("BFloat16Storage"))
{
tensor.Type = Structs.GGmlType.GGML_TYPE_BF16;
}
break;
}
case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items
{
if (binPersid)
{
readStrides = true;
}
break;
}
case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack top
if (binPersid)
{
readStrides = true;
}
break;
case (byte)'t': // TUPLE = b't' # build tuple from topmost stack items
deepth--;
if (binPersid)
{
readStrides = true;
}
break;
case (byte)'R': // REDUCE = b'R' # apply callable to argtuple, both on stack
if (deepth == 1)
{
if (tensor.Name.Contains("metadata"))
{
break;
}
if (string.IsNullOrEmpty(tensor.DataNameInZipFile))
{
tensor.DataNameInZipFile = tensors.Last().DataNameInZipFile;
tensor.Offset = new List<ulong> { (ulong)tensor.Shape[0] * Common.GetGGmlTypeSize(tensor.Type) };
tensor.Shape.RemoveAt(0);
}
tensors.Add(tensor);
tensor = new Tensor() { FileName = fileName, Offset = { 0 } };
readStrides = false;
binPersid = false;
}
break;
case (byte)'.': // STOP = b'.' # every pickle ends with STOP
finished = true;
break;
default:
break;
}
index++;
}
Tensor metaTensor = tensors.Find(x => x.Name.Contains("_metadata"));
if (metaTensor != null)
{
tensors.Remove(metaTensor);
}
return tensors;
}
The Tensor class is defined as follows:
public class Tensor
{
public string Name { get; set; }
public Structs.GGmlType Type { get; set; } = Structs.GGmlType.GGML_TYPE_F16;
public List<long> Shape { get; set; } = new List<long>();
public List<ulong> Stride { get; set; } = new List<ulong>();
public string DataNameInZipFile { get; set; }
public string FileName { get; set; }
public List<ulong> Offset { get; set; } = new List<ulong>();
public long BodyPosition { get; set; }
}
Because this code was originally written for using ggml from C#, the tensor type uses ggml's precision enum; adapt it to your own platform as needed.
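Given Shape and the element size, the number of bytes a tensor occupies follows directly. A small helper as a sketch (ByteSize is illustrative; typeSize stands in for Common.GetGGmlTypeSize, e.g. 2 for F16, 4 for F32):

```csharp
using System;
using System.Collections.Generic;

static class TensorSize
{
    // Byte size = product of all shape dimensions * bytes per element.
    public static ulong ByteSize(List<long> shape, ulong typeSize)
    {
        ulong count = 1;
        foreach (long dim in shape)
        {
            count *= (ulong)dim;
        }
        return count * typeSize;
    }
}
```

For lm_head.weight from the dump above, shape [32000, 4096] in F16 gives 32000 * 4096 * 2 = 262,144,000 bytes (the dump's storage element count 16777216-style field reads 131072000 = 32000 * 4096 for this tensor).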
Reading the Tensor Weight Values
Once the tensor structures are known, this part can be read: stream the entry and read everything as raw bytes. Each tensor records its own type when declared, so each target platform can convert the bytes itself at compute time.
public byte[] ReadByteFromFile(Tensor tensor)
{
if (entries is null)
{
throw new ArgumentNullException(nameof(entries));
}
ZipArchiveEntry dataEntry = entries.First(e => e.Name == tensor.DataNameInZipFile);
long i = 1;
foreach (var ne in tensor.Shape)
{
i *= ne;
}
ulong length = Common.GetGGmlTypeSize(tensor.Type) * (ulong)i;
byte[] data = new byte[dataEntry.Length];
using (Stream stream = dataEntry.Open())
{
    int read = 0;
    while (read < data.Length)
    {
        int n = stream.Read(data, read, data.Length - read);
        if (n == 0) break; // unexpected end of stream
        read += n;
    }
}
byte[] result = new byte[length];
Array.Copy(data, (long)tensor.Offset[0], result, 0, (long)length);
return result;
}
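For example, the raw bytes of a GGML_TYPE_F16 tensor can be turned into floats with the Half type (available since .NET 5). This helper is a sketch, not part of PickleLoader:

```csharp
using System;

static class TensorData
{
    // Reinterpret the little-endian F16 bytes returned by ReadByteFromFile as floats.
    public static float[] HalfBytesToFloats(byte[] raw)
    {
        float[] values = new float[raw.Length / 2];
        for (int i = 0; i < values.Length; i++)
        {
            values[i] = (float)BitConverter.ToHalf(raw, i * 2);
        }
        return values;
    }
}
```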
Limitations
Although this method can read many Pickle-format weight files, it cannot read all of them. It currently supports only files that contain nothing but weights; if the data.pkl file also contains graph information, the method does not yet work. If you are interested, you are welcome to help me improve PickleLoader and contribute to deep learning in C# together.
Summary
Reading Pickle-format weight files in C# is harder than reading Safetensors files, because C# has no good way to deserialize the data.pkl file; it can only be read in the order the file defines. Few people currently do deep learning in C#, so implementations of this functionality are scarce. I wrote this article hoping to help more enthusiasts who like building deep learning projects in C# realize their own projects more easily.
The complete code of this project can be downloaded from the "read Pickle-format weights directly from C#" link.
This module comes from the GGMLSharp project I am developing; if you like the project, please give it a star on GitHub.