Pytorch常见报错、踩坑点总结(持续更新)

在利用PyTorch进行深度学习的路上踩坑点实在是太多了,因此打算总结一些,以便日后查阅使用。

一、transforms.ToTensor()

在运行下面一段程序的时候,发现报错提醒:

D:\Anaconda3\envs\py36\lib\site-packages\torchvision\datasets\mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at …\torch\csrc\utils\tensor_numpy.cpp:180.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

import torch
import torchvision
from torchvision import transforms

#读取数据集
#transforms.ToTensor()将尺寸为 (H x W x C)
# 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True,
                                                transform=trans,
                                                download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False,
                                               transform=trans, download=True)

print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

翻译一下,大概意思就是:

给定的NumPy数组不可写,PyTorch不支持不可写张量。这意味着您可以使用张量写入底层(假定不可写)NumPy数组。在将数组转换为张量之前,可能需要复制数组以保护其数据或使其可写。

网上查找原因,在这个帖子中发现:

transform=torchvision.transforms.ToTensor()起到的作用是把PIL.Image或者numpy.narray数据类型转变为torch.FloatTensor类型,shape是CHW,数值范围缩小为[0.0, 1.0]。

在用数据读取的时候无需换成PIL的数据格式了,直接numpy格式就可以,无需转来转去,多此一举,transform可以直接处理numpy数据!

但是基于ToTensor()源码就是会将numpy转换为Tensor格式,下载的最新PyTorch也不是很了解,故暂且忽略这个用法警告,以后懂了回来填坑。同时也请有幸看到此贴的大佬指正(毕竟我还是个刚开始学习的小辣鸡)

二、torch.utils.data.DataLoader()的num_workers值的问题

若指定进程数为除0以外的其他数,如为4,运行下面一段代码:

import time
import torch
import torchvision
from torchvision import transforms

#读取数据集
#transforms.ToTensor()将尺寸为 (H x W x C)
# 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True,
                                                transform=trans,
                                              
  • 12
    点赞
  • 52
    收藏
    觉得还不错? 一键收藏
  • 14
    评论
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值