Tensor类型
PyTorch 定义的 tensor 类型(默认 tensor 类型为 FloatTensor):
不同 tensor 类型所占用的内存是有差别的,这也可能是你 oom 的原因之一,例如模型中对 tensor 类型做了转换,导致了原本可装载的 tensor 爆了显存。
一般地,在图像分类等项目中,FloatTensor 和 LongTensor 是最常用的两个 tensor 类型。
那么,PyTorch 中的 tensor 和 Numpy 中的 ndarray 又有什么区别呢?其实主要区别在于 tensor 可以部署在 GPU 和 CPU (大佬也可能是 TPU )中去进行计算,而 ndarray 只能在 CPU 中进行计算。
PyTorch 的 tensor 和 Numpy 的 ndarray 如何相互转换?
函数 | 功能 |
---|---|
tensor.numpy() | 将 Tensor 转变为 ndarray |
torch.from_numpy(ndarray) | 将 ndarray 类型转变为 Tensor |