一. 数据输入的类型
pytorch的基本数据结构是张量Tensor
1.张量的数据类型
张量的数据类型和numpy.array基本一一对应,但是不支持str类型。
包括:
torch.float16
torch.float32(torch.float)
torch.float64(torch.double)
torch.int8
torch.uint8
torch.int16
torch.int32(torch.int)
torch.int64(torch.long)
torch.bool
一般的神经网络建模使用的都是torch.float32类型
2. 张量的维度和尺寸
常用的方法
#查看维度
dim()
#查看形状尺寸
size()
shape
#改变尺寸
reshape()
view()
3.张量、numpy数组、list的相互转化
张量转化为numpy数组, 借助numpy()方法
numpy数组转化为张量,借助torch.from_numpy()
注意上面两种方法是共享内存的,一个改变另一个也会改变。
可以用张量的clone()方法来中断这种联系,tensor.data.numpy()也可以