[Pytorch] torch.save原理

torch.save调用serialization.py里的_save:

def _save(obj, zip_file, pickle_module, pickle_protocol):
    serialized_storages = {}
    id_map: Dict[int, str] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            obj_key = id_map.setdefault(obj._cdata, str(len(id_map)))
            location = location_tag(obj)
            serialized_storages[obj_key] = obj

            return ('storage',
                    storage_type,
                    obj_key,
                    location,
                    obj.size())
        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record('data.pkl', data_value, len(data_value))

    # Write each tensor to a file named tensor/the_tensor_key in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f'data/{key}'
        storage = serialized_storages[key]
        # given that we copy things around anyway, we might use storage.cpu()
        # this means to that to get tensors serialized, you need to implement
        # .cpu() on the underlying Storage
        if storage.device.type != 'cpu':
            storage = storage.cpu()
        # Now that it is on the CPU we can directly copy it into the zip file
        num_bytes = storage.size() * storage.element_size()
        zip_file.write_record(name, storage.data_ptr(), num_bytes)

persistent_id把Tensor里的Storage先放到serialized_storages里面;

pickler会把除Storage之外的Tensor字段,都serialize到data_value里面,写入文件;

serialized_storages里面的“内存块”们,单独写入文件;

这里的"storage.data_ptr()"等,调用的是torch\csrc\generic\StorageMethods.cpp里的THPStorage_(dataPtr)

// Python-exposed Storage.data_ptr(): returns the raw address of the
// storage's underlying memory, boxed as a Python integer.
static PyObject * THPStorage_(dataPtr)(PyObject *_self, PyObject *noargs)
{
  HANDLE_TH_ERRORS
  // _self is the CPython wrapper object; cdata holds the C storage handle.
  auto self = (THPStorage*)_self;
  // THWStorage_(data) yields the data pointer; PyLong_FromVoidPtr converts
  // it to a PyLong. NOTE(review): LIBRARY_STATE is a codegen macro whose
  // expansion depends on the build (empty for CPU) -- confirm per backend.
  return PyLong_FromVoidPtr(THWStorage_(data)(LIBRARY_STATE self->cdata));
  END_HANDLE_TH_ERRORS
}

THPStorage定义:torch\csrc\StorageDefs.h

// CPython object layout wrapping a TH storage: the standard object header
// plus a single pointer to the underlying C storage.
struct THPStorage {
  PyObject_HEAD
  THWStorage *cdata;  // underlying C storage handle
};

THWStorage定义:torch\csrc\THP.h

// THWStorage is the codegen-facing alias; it resolves to THStorage.
#define THWStorage THStorage

THStorage定义:aten\src\TH\generic\THStorage.h

// The generic TH storage type is backed directly by c10's StorageImpl.
#define THStorage at::StorageImpl

所以,self->cdata就是at::StorageImpl对象;storage.data_ptr()调用的是THWStorage_(data)(<pStorageImpl>)

其实就是THStorage_(data)(<pStorageImpl>),定义在aten\src\TH\generic\THStorage.cpp

// Returns a typed pointer to the storage's raw memory.
scalar_t* THStorage_(data)(const THStorage *self)
{
#if defined(THQUANTIZED)
  // Quantized builds store quantized_t elements; reinterpret the pointer
  // to the scalar type this generic function is instantiated for.
  return reinterpret_cast<scalar_t*>(self->data<quantized_t>());
#else
  return self->data<scalar_t>();
#endif
}

所以,调用的是at::StorageImpl的data()函数,定义在c10\core\StorageImpl.h

  // Typed accessor over the storage's raw memory; delegates to unsafe_data().
  template <typename T>
  inline T* data() const {
    return unsafe_data<T>();
  }

  // Casts the held DataPtr's raw pointer to T* with no type checking --
  // "unsafe" because the caller asserts the element type is correct.
  template <typename T>
  inline T* unsafe_data() const {
    return static_cast<T*>(this->data_ptr_.get());
  }

Reference: [Pytorch] Tensor底层机制_smartcat2010的博客-CSDN博客

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值