在利用PyTorch进行深度学习的路上踩坑点实在是太多了,因此打算总结一些,以便日后查阅使用。
一、transforms.ToTensor()
在运行下面一段程序的时候,发现报错提醒:
D:\Anaconda3\envs\py36\lib\site-packages\torchvision\datasets\mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at …\torch\csrc\utils\tensor_numpy.cpp:180.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
import torch
import torchvision
from torchvision import transforms
#读取数据集
#transforms.ToTensor()将尺寸为 (H x W x C)
# 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True,
transform=trans,
download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False,
transform=trans, download=True)
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
翻译一下,大概意思就是:
给定的NumPy数组不可写,PyTorch不支持不可写张量。这意味着您可以使用张量写入底层(假定不可写)NumPy数组。在将数组转换为张量之前,可能需要复制数组以保护其数据或使其可写。
网上查找原因,在这个帖子中发现:
transform=torchvision.transforms.ToTensor()起到的作用是把PIL.Image或者numpy.narray数据类型转变为torch.FloatTensor类型,shape是CHW,数值范围缩小为[0.0, 1.0]。
在用数据读取的时候无需换成PIL的数据格式了,直接numpy格式就可以,无需转来转去,多此一举,transform可以直接处理numpy数据!
但是基于ToTensor()源码就是会将numpy转换为Tensor格式,下载的最新PyTorch也不是很了解,故暂且忽略这个用法警告,以后懂了回来填坑。同时也请有幸看到此贴的大佬指正(毕竟我还是个刚开始学习的小辣鸡)
二、torch.utils.data.DataLoader()的num_workers值的问题
若指定进程数为除0以外的其他数,如为4,运行下面一段代码:
import time
import torch
import torchvision
from torchvision import transforms
#读取数据集
#transforms.ToTensor()将尺寸为 (H x W x C)
# 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True,
transform=trans,