PyTorch Tensor创建方式、cpu/gpu/数据类型和设备指定、设定默认数据类型和设备

PyTorch Tensor创建、cpu/gpu/数据类型指定

无论如何,请多注意数据类型和设备指定,避免不必要的错误。

torch的默认数据类型和设备

  1. 默认的整数类型是torch.int64,默认的浮点类型是torch.float32,默认设备为cpu
  2. 创建时可以通过dtype参数来指定类型,如torch.tensor(1., dtype=torch.float64)
  3. 创建时可以指定设备,如torch.arange(12, device="cuda")

torch设置默认数据类型和设备

  1. 通过torch.set_default_tensor_type方法指定数据类型和设备
  2. 该方法的参数是torch的Tensor类型,这里类型包含了整数、浮点、设备
  3. 只能将浮点类型指定为默认类型
## 指定数据存放在gpu上,默认类型是float64
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
t = torch.tensor(1.)
print(t.device, t.dtype, sep='\n')
# cuda:0
# torch.float64

下面列出float的相关类型说明,查看完整说明请看官方文档
在这里插入图片描述

Tensor创建的多种方法

torch.tensor

这是最直接的方式,但只适合简单张量的创建

t = torch.tensor([1, 2]) # torch.int64
t = torch.tensor([1, 2], dtype=torch.float16) # torch.float16

torch.arange和torch.linspace

与numpy相似

  1. arange(start=0, end, step=1)
  2. linspace(start, end, steps=100)
t = torch.arange(12)
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
t = torch.arange(2, 12, 2)
# tensor([2, 4, 6, 8, 10])
t = torch.linspace(1, 2, 7)
# tensor([1.0000, 1.1667, 1.3333, 1.5000, 1.6667, 1.8333, 2.0000])

注意这里也可以通过dtype指定为浮点数

torch.zeros、torch.ones、torch.randn

需要注意的是与numpy区分

  1. numpy的必须参数是shape类型数据,必须用括号括起来
  2. torch的必须参数是*size类型数据,这里可参考函数传递可变参数理解
torch.zeros(2, 2)
numpy.zeros((2, 2))

torch.as_tensor

一个强大的转换工具,可以将多种数据类型转换为tensor,例如

aArray = numpy.array([3, 4, 5])
aList = [1, 2, 3]
t1 = torch.as_tensor(aList)
t2 = torch.as_tensor(n)

torch.as_strided

从一个存在的Tensor中通过步幅来得到新的Tensor,下面举个例子说明

t1 = torch.randn(3, 3)
print(t1)
t2 = torch.as_strided(t1, (2, 2), (1, 2))
t2

上面代码段的输出为
tensor([[ 0.5368, 1.0748, -1.5365],
    [-0.1329, 0.2131, -1.9862],
    [-0.1880, -1.1600, 1.3971]])
tensor([[ 0.5368, -1.5365],
    [ 1.0748, -0.1329]])

此外该方法还可以指定初始偏移量(原文档说是内存偏移量,我理解就是首个元素的偏移量,可能不对)
t2 = torch.as_strided(t1, (1, 2), (1, 2), 1)


关于torch.as_strided方法的理解,推荐移步here,其内在就是存储的偏移量,本文不再赘述。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值