torch.save调用serialization.py里的_save:
def _save(obj, zip_file, pickle_module, pickle_protocol):
    """Serialize ``obj`` into ``zip_file``.

    The pickled object graph (everything except storage payloads) is written
    as the ``data.pkl`` record; each distinct storage's raw bytes are written
    as a separate ``data/<key>`` record.
    """
    # storage key (str) -> storage object; filled in by persistent_id()
    # as the pickler walks the object graph.
    serialized_storages = {}
    # C-level storage address (obj._cdata) -> stable string key, so a
    # storage shared by several tensors is recorded and written only once.
    id_map: Dict[int, str] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            # Reuse the key for an already-seen storage; otherwise assign
            # the next sequential key ("0", "1", ...).
            obj_key = id_map.setdefault(obj._cdata, str(len(id_map)))
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            # This tuple is what gets pickled in place of the storage;
            # the payload itself is written separately below.
            return ('storage',
                    storage_type,
                    obj_key,
                    location,
                    obj.size())
        # Returning None tells pickle to serialize `obj` the normal way.
        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record('data.pkl', data_value, len(data_value))

    # Write each tensor to a file named tensor/the_tensor_key in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f'data/{key}'
        storage = serialized_storages[key]
        # given that we copy things around anyway, we might use storage.cpu()
        # this means to that to get tensors serialized, you need to implement
        # .cpu() on the underlying Storage
        if storage.device.type != 'cpu':
            storage = storage.cpu()
        # Now that it is on the CPU we can directly copy it into the zip file
        num_bytes = storage.size() * storage.element_size()
        # write_record copies num_bytes straight from the storage's memory.
        zip_file.write_record(name, storage.data_ptr(), num_bytes)
persistent_id把Tensor里的Storage先放到serialized_storages里面;
pickler会把除Storage之外的Tensor字段,都serialize到data_value里面,写入文件;
serialized_storages里面的“内存块”们,单独写入文件;
这里的"storage.data_ptr()"等,调用的是torch\csrc\generic\StorageMethods.cpp里的THPStorage_(dataPtr)
// Python binding for Storage.data_ptr(): returns the storage's raw data
// address as a Python integer (PyLong_FromVoidPtr). Takes no arguments.
static PyObject * THPStorage_(dataPtr)(PyObject *_self, PyObject *noargs)
{
  HANDLE_TH_ERRORS
  auto self = (THPStorage*)_self;
  // self->cdata is the underlying THWStorage; THWStorage_(data) yields its
  // data pointer, which is wrapped into a PyLong for Python.
  return PyLong_FromVoidPtr(THWStorage_(data)(LIBRARY_STATE self->cdata));
  END_HANDLE_TH_ERRORS
}
THPStorage定义:torch\csrc\StorageDefs.h
// CPython object that wraps a C-level storage: the standard PyObject header
// plus a pointer to the underlying THWStorage (expands to at::StorageImpl).
struct THPStorage {
  PyObject_HEAD
  THWStorage *cdata;  // owned C-level storage handle
};
THWStorage定义:torch\csrc\THP.h
#define THWStorage THStorage
THStorage定义:aten\src\TH\generic\THStorage.h
#define THStorage at::StorageImpl
所以,self->cdata就是at::StorageImpl对象;storage.data_ptr()调用的是THWStorage_(data)(<pStorageImpl>)
其实就是THStorage_(data)(<pStorageImpl>),定义在aten\src\TH\generic\THStorage.cpp
// Return the storage's data pointer cast to this type's scalar_t.
scalar_t* THStorage_(data)(const THStorage *self)
{
#if defined(THQUANTIZED)
  // Quantized builds store quantized_t elements; reinterpret to scalar_t.
  return reinterpret_cast<scalar_t*>(self->data<quantized_t>());
#else
  return self->data<scalar_t>();
#endif
}
所以,调用的是at::StorageImpl的data()函数,定义在c10\core\StorageImpl.h
// Typed accessor for the storage's buffer; delegates to unsafe_data<T>()
// (no type checking is performed here).
template <typename T>
inline T* data() const {
  return unsafe_data<T>();
}
// Cast the raw pointer held by data_ptr_ to T* with no checking —
// the caller is responsible for T matching the stored element type.
template <typename T>
inline T* unsafe_data() const {
  return static_cast<T*>(this->data_ptr_.get());
}
Reference: [Pytorch] Tensor底层机制 — smartcat2010的博客（CSDN）