torch与numpy转换内存上的小坑

在学习机器学习中,tensor的各种转换是新手容易遇到的坑 ,我这里记录一下我遇到的一些坑

  1. 将numpy数据类型转换成Tensor

a = torch.ones(5)
b = a.numpy()
a.add_(1)  # 就地版本的add()
print(a)
print(b)
tensor([2., 2., 2., 2., 2.])
[2. 2. 2. 2. 2.]

torch中的add_()是就地版本的add(),这样b的值会随a变化,而若使用add() 则b的值是全1

  1. 将numpy数组转化成Torch的Tensor

import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)  
np.add(a,1,out=a)
print(a)
print(b)
[2. 2. 2. 2. 2.]
tensor([2., 2., 2., 2., 2.], dtype=torch.float64)

torch.from_numpy()会创建一个tensor从numpy转化来的,返回的tensor和之前的narray共享内存,下面是官方的解释:

Creates a Tensor from a numpy.ndarray.

The returned tensor and ndarray share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.

  1. 复制tensor数据类型的数据时候

x = torch.arange(12)
y = torch.tensor(x)

这样复制时会报UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). y = torch.tensor(x)的警告

x = torch.arange(12)
y = torch.as_tensor(x)

使用as_tensor做复制可以不报错,是官方推荐的写法

  1. torch的reshape()是返回的一个view

a = torch.arange(12)
b = a.reshape((3,4))
b[:] = 2
a
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值