Pytorch中tensor与ndarray类型转换及标量转换

tensor与ndarrary的转换

Pytorch中的tensor与ndarray在底层数据类型设计有相似之处,在Pytorch框架中tensor与ndarray可以较为方便地转换

tensor转ndarray

tensor转ndarray分为浅拷贝与深拷贝

浅拷贝

浅拷贝一般使用numpy()方法

import torch
import numpy as np

data1 = torch.tensor([1, 2, 3])
print(data1)
data2 = data1.numpy()
print(data2)
data1[0] = 9
print(data1)
print(data2)
# tensor([1, 2, 3])
# [1 2 3]
# tensor([9, 2, 3])
# [9 2 3]

可以看到,在对转换成ndarray类型的data2进行修改后,tensor的值也随之改变,这是因为二者底层共用一块,为浅拷贝

深拷贝

深拷贝我们可以对tensor进行clone()后再进行转换,clone()会拷贝一份完全独立的张量,并会拷贝计算图

import torch
import numpy as np

data1 = torch.tensor([1, 2, 3])
print(data1)
data2 = data1.clone().numpy()
print(data2)
data1[0] = 9
print(data1)
print(data2)
# tensor([1, 2, 3])
# [1 2 3]
# tensor([9, 2, 3])
# [1 2 3]

可以看到这里在对张量进行修改后,并不会影响ndarray,因为这里为深拷贝

ndarray转tensor

ndarray转tensor同样分为深拷贝和浅拷贝

浅拷贝

浅拷贝一般是通过torch.from_numpy()实现的

import torch
import numpy as np

data1 = np.array([1, 2, 3])
data2 = torch.from_numpy(data1)
print(data1)
print(data2)
data1[0] = 9
print(data1)
print(data2)
# [1 2 3]
# tensor([1, 2, 3], dtype=torch.int32)
# [9 2 3]
# tensor([9, 2, 3], dtype=torch.int32)

可以看到浅拷贝后,对共享内存的任意一个对象修改都会影响到另一个的值

深拷贝

深拷贝这里我们可以通过对ndarray进行copy()进行深拷贝创立副本

import torch
import numpy as np

data1 = np.array([1, 2, 3])
data2 = torch.from_numpy(data1.copy())
print(data1)
print(data2)
data1[0] = 9
print(data1)
print(data2)
# [1 2 3]
# tensor([1, 2, 3], dtype=torch.int32)
# [9 2 3]
# tensor([1, 2, 3], dtype=torch.int32)

张量提取标量

tensor可以分为矢量张量和标量张量,对于从张量中提取标量值一般可以使用item()方法,要求tensor为单个元素才可以使用

import torch
import numpy as np

data1 = torch.tensor(1)
data2 = torch.tensor([1])
print(data1)
print(data2)
print(data1.item())
print(data2.item())
# tensor(1)
# tensor([1])
# 1
# 1
import torch
import numpy as np

data1 = torch.tensor([1, 2, 3])
print(data1)

print(data1.item())
tensor([1, 2, 3])
# Traceback (most recent call last):
#   File "D:\Pythonproject\teach_day_01\demo02.py", line 7, in <module>
#     print(data1.item())
#           ^^^^^^^^^^^^
# RuntimeError: a Tensor with 3 elements cannot be converted to Scalar

可以看到非标量张量无法进行item()标量值提取

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值