Original: https://mp.weixin.qq.com/s/TStrHMbgDjPIsXaPK6-VSQ
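1 Introduction to dill
pickle and dill can both serialize most Python objects so they can be saved and restored later. The standard-library pickle cannot serialize lambda functions and some other objects, whereas dill can.
pickle and dill are used in the same way (dump, dumps, load, loads).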
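As a quick illustration of that difference, the sketch below tries to serialize the same lambda with both libraries. It assumes dill is installed; the variable names are illustrative only.

```python
import pickle

import dill

squared = lambda x: x**2

# The standard pickle stores functions by name, so a lambda
# (whose name is "<lambda>") cannot be looked up again and fails.
try:
    pickle.dumps(squared)
    pickle_ok = True
except (pickle.PicklingError, AttributeError) as exc:
    pickle_ok = False
    print("pickle failed:", exc)

# dill stores the lambda by value, so the round trip works.
restored = dill.loads(dill.dumps(squared))
print(pickle_ok, restored(4))
```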
Features of dill:
1) It can pickle the following standard types:
- none, type, bool, int, long, float, complex, str, unicode,
- tuple, list, dict, file, buffer, builtin,
- both old and new style classes,
- instances of old and new style classes,
- set, frozenset, array, functions, exceptions
2) It can also pickle more exotic types:
- functions with yields, nested functions, lambdas,
- cell, method, unboundmethod, module, code, methodwrapper,
- dictproxy, methoddescriptor, getsetdescriptor, memberdescriptor,
- wrapperdescriptor, xrange, slice,
- notimplemented, ellipsis, quit
3) The following types cannot be pickled yet:
- frame, generator, traceback
4) Other things dill can do:
- save and load python interpreter sessions
- save and extract the source code from functions and classes
- interactively diagnose pickling errors
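For the diagnostic use in the last item, here is a minimal sketch based on dill.pickles, which reports whether an object survives a dill round trip. The sample objects are illustrative; the generator object is expected to fail because generators are on the unsupported list above.

```python
import dill

def count_up(n):
    """A generator function; calling it returns a generator object."""
    for i in range(n):
        yield i

candidates = {
    "lambda": lambda x: x + 1,
    "generator function": count_up,
    "generator object": count_up(3),
}

# dill.pickles(obj) returns True if obj can be pickled and restored by dill.
report = {name: dill.pickles(obj) for name, obj in candidates.items()}
for name, ok in report.items():
    print(f"{name}: {'ok' if ok else 'cannot be pickled'}")
```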
The following sections show several ways to use dill.
2 Saving anonymous functions
# !pip install dill
import dill
# Serialize a lambda to bytes, load it back, and call it
squared = lambda x: x**2
dill.loads(dill.dumps(squared))(3)
# 9
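The same round trip also works for nested functions that capture variables, one of the types listed in the introduction. A small sketch, where make_power is a hypothetical helper introduced only for illustration:

```python
import dill

def make_power(exponent):
    """Return a nested function that closes over `exponent`."""
    def power(x):
        return x ** exponent
    return power

cube = make_power(3)

# dumps() returns bytes; loads() rebuilds the function together with its closure.
payload = dill.dumps(cube)
cube_restored = dill.loads(payload)
print(type(payload).__name__, cube_restored(2))
```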
3 Viewing source code
# Print the source code of the function
import dill.source
print(dill.source.getsource(squared))
In IPython this raises OSError: could not extract source code. Use the following form instead:
code = dill.source.getsource(dill.detect.code(squared))
print(code)
# squared = lambda x: x**2
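4 Saving Dataset and DataLoader objects
Reloading saved objects avoids repeating the data preprocessing, which can speed it up.
# Build the Dataset and DataLoader to save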
from torch.utils.data import TensorDataset, DataLoader
import torch
from sklearn.datasets import make_classification
data, target = make_classification()
# data.shape, target.shape
# ((100, 20), (100,))
batch_size=10
dataset = TensorDataset(torch.from_numpy(data), torch.from_numpy(target))
dataloader = DataLoader(dataset, shuffle=False, drop_last=True, batch_size=batch_size)
Save the data:
dill.dump(dataset, './dataset_save.pkl')
dill.dump(dataloader, './dataloader_save.pkl')
Passing a file path directly raises an error: TypeError: file must have a 'write' attribute
Open the file in binary write mode and pass the file object instead:
with open('./dataset_save.pkl', 'wb') as f:
    dill.dump(dataset, f)
with open('./dataloader_save.pkl', 'wb') as f:
    dill.dump(dataloader, f)
Load the data:
with open('./dataset_save.pkl', 'rb') as f:
    dataset_save = dill.load(f)
with open('./dataloader_save.pkl', 'rb') as f:
    dataloader_save = dill.load(f)
Compare the data:
x, y = next(iter(dataloader))
x_save, y_save = next(iter(dataloader_save))
torch.equal(x, x_save), torch.equal(y, y_save)
# (True, True)
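The open-then-dump pattern used above can be wrapped in two small helpers. This is only a sketch: save_obj and load_obj are hypothetical names, not part of dill, and a plain dict holding a lambda stands in for the Dataset so that the example runs without torch.

```python
import os
import tempfile

import dill

def save_obj(obj, path):
    """Serialize obj to path; dill.dump needs a binary file object, not a path."""
    with open(path, "wb") as f:
        dill.dump(obj, f)

def load_obj(path):
    """Load an object previously written by save_obj."""
    with open(path, "rb") as f:
        return dill.load(f)

# Stand-in for preprocessed data plus a transform that plain pickle would reject.
pipeline = {"data": [1, 2, 3], "transform": lambda x: x * 10}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "pipeline_save.pkl")
    save_obj(pipeline, path)
    restored = load_obj(path)

transformed = [restored["transform"](v) for v in restored["data"]]
print(transformed)
```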