Pytorch 基本的数据类型

    Pytorch中的基本数据类型就是各式各样的张量,张量可以理解为多维矩阵。Pytorch中定义了一个Tensor类来实现张量,Tensor在使用上与numpy的ndarray类似,不同的是,Tensor可以在GPU上运行,但是numpy只能在CPU上运行,当然numpy与Tensor可以进行相互转换,以此使得numpy数据在GPU上运行。Pytorch中的Tensor又包括CPU上的数据类型和GPU上的数据类型,两种数据类型之间也可以进行相互转换,下面是Pytorch定义的数据类型:

1. 类型检查

Pytorch数据类型的检查可以通过三个方式:

1)python内置函数type()

2)Tensor的成员函数Tensor.type()

3)Pytorch提供的工具函数isinstance()

示例代码:

import torch

a = torch.randn(2, 3)  # 2行3列,正态分布~N(0,1)
print(a)
print(type(a))
print(a.type())
print(isinstance(a, torch.FloatTensor))

运行结果:

tensor([[-0.6646,  0.3935,  1.2683],
        [-1.8576,  0.2761,  1.4787]])
<class 'torch.Tensor'>
torch.FloatTensor
True

从运行结果来看,python内置的类型检测函数type()只能检查该数据是Tensort类型,具体的基本数据类型无法检测出来,而Tensor的成员函数type()更加直观,可以检测出Tensor数据的基本类型。isinstance()函数主要用于判断某数据是否属于某个数据类型,如果属于返回True,否则返回False。

2. 数据类型转换

Tensor类型的变量进行类型转换一般有两种方法:

1)Tensor类型的变量直接调用long(), int(), double(),float(),byte()等函数就能将Tensor进行类型转换;

2)在Tensor成员函数type()中直接传入要转换的数据类型。

当你不知道要转换为什么类型时,但需要求a1,a2两个张量的乘积,可以使用a1.type_as(a2)将a1转换为a2同类型。

示例代码:

import torch

a = torch.randn(2, 3)
print(a.type())

# 转换为IntTensort类型
b = a.int()

# 转换为LongTensor类型
c = a.type(torch.LongTensor)

print(b.type())
print(c.type())

# 将a转换为与b相同的类型
a.type_as(b)
print(a.type())

运行结果:

torch.FloatTensor
torch.IntTensor
torch.LongTensor
torch.FloatTensor

3. Tensor与Numpy ndarray之间的转换

Tensor和numpy.ndarray之间还可以相互转换,其方式如下:

1)Numpy转化为Tensor:torch.from_numpy(numpy矩阵)

2)Tensor转化为numpy:Tensor矩阵.numpy()

示例代码:

import torch
import numpy as np

# 定义一个3行2列的全为0的矩阵
b = torch.randn((3, 2))

# tensor转化为numpy
numpy_b = b.numpy()
print(numpy_b)

# numpy转化为tensor
numpy_e = np.array([[1, 2], [3, 4], [5, 6]])
torch_e = torch.from_numpy(numpy_e)

print(numpy_e)
print(torch_e)

4. CPU或GPU张量之间的转

1) CPU张量 ---->  GPU张量, 使用Tensor.cuda()

2) GPU张量 ----> CPU张量 使用Tensor.cpu()

我们可以通过torch.cuda.is_available()函数来判断当前的环境是否支持GPU,如果支持,则返回True。所以,为保险起见,在项目代码中一般采取“先判断,后使用”的策略来保证代码的正常运行,其基本结构如下:

import torch

# 定义一个3行2列的全为0的矩阵
tmp = torch.randn((3, 2))

# 如果支持GPU,则定义为GPU类型
if torch.cuda.is_available():
    inputs = tmp.cuda()
# 否则,定义为一般的Tensor类型
else:
    inputs = tmp


 

 

  • 8
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

洪流之源

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值