解决:TypeError: can‘t convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor

项目场景:

在测试pytorch训练出的模型的时候,遇到如下报错:

Traceback (most recent call last):
  File "D:\PycharmProjects\RL_TSP_4static-master\Post_process\load_all_reward.py", line 81, in <module>
    objs[i,:] = [obj1, obj2]
  File "D:\Anaconda\envs\DRL\lib\site-packages\torch\_tensor.py", line 1032, in __array__
    return self.numpy().astype(dtype, copy=False)
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

问题描述

这个报错说的是无法将tensor类型转化为numpy(objs是一个numpy数组,obj1和obj2分别都是tensor对象)。首先,tensor,中文叫做张量,它是Pytorch中最基本的数据类型,同时也是最最最重要的数据类型,基本上后面你无时不刻都在和它打交道。tensor具体的数学含义大家可以自行去了解。

关键报错代码:

    def __array__(self, dtype=None):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
        if dtype is None:
            return self.numpy()
        else:
            return self.numpy().astype(dtype, copy=False)

其中,传入的数据是cuda上的Tensor类型。


原因分析:

正如Numpy中所有的操作都是针对Numpy特有的变量类型Array,在Pytorch中几乎所有的操作都是针对Pytorch特有的变量类型Tensor。它们之间最大的不同是:Tensor既可以存储在GPU(或者’cuda’设备)上,也可以存储在cpu设备上(准确来讲是cpu的内存上);而Array只能存储在cpu的内存上。经过对问题的了解,可以确定问题就出在将tensor转化为numpy的过程中。


解决方案:

根据报错,我们可以看出原始数据是位于GPU(cuda)上的,所以在用.numpy()之前先得用.cpu()把它传到cpu上,修改后代码如下:

    def __array__(self, dtype=None):
        if has_torch_function_unary(self):
            return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
        if dtype is None:
            return self.cpu().numpy()
        else:
            return self.cpu().numpy().astype(dtype, copy=False)

Tensor.cpu()的作用是将cuda上的Tensor传回到cpu上。我的Tensor已经提前剔除梯度信息了,如果tensor本身还包含梯度信息,需要先把用.detach()梯度信息剔除掉,再转成numpy。如下:

Tensor.cpu().detach().numpy()

如果有错误,希望大家批评指正。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ppdd·~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值