pytorch是一个很好用的工具,对于数据最好使用它封装好的dataset、dataLoader结构。这样既能够让代码有更好的可读性,主要是保证内存释放的问题。
对于自定义的数据集,必须要继承dataset类并结合自定义的数据来重写它的两个成员方法。
相关链接:
1.https://blog.csdn.net/luolinll1212/article/details/82871729
2.https://blog.csdn.net/yt4766269/article/details/77923422
一、代码示例
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from PIL import Image
'''
#自定义的Dataset需要继承它并实现它的两个成员方法
#自定义Dataset只能够使用自定义的Transform步骤,重点是__getitem__函数
#transform函数是在pil.image上操作的,所以需要将img(numpy的数组)转化为转化为PIL Image类型
#transform函数是在pil.image上操作的,所以需要将img(numpy的数组)转化为转化为PIL Image类型
'''
class dataset_view1(Dataset):
def __init__(self,data,name,transform=None,datalen=None):
self.data=data
self.name=name
self.transform=transform
self.datalen=datalen
#支持索引,dataset[i]可以用来获取i样本
def __getitem__(self, index):
label = self.data[1][index]
tag = self.data[2][index]
img=self.data[0][index]
img= Image.fromarray(img, mode='RGB')
#Image.fromarray(np.uint8(img))
if self.transform is not None:
img = self.transform(img)
return img, label,tag
#返回数据集的大小
def __len__(self):
datalen=len(self.data[0])
return datalen
二、涉及的相关知识
*自定义Dataset只能够使用自定义的Transform步骤
- pil image.fromarray
transform函数是在pil.image上操作的,所以需要将img(numpy的数组)转化为转化为PIL Image类型。
#将numpy的数组转化为PIL Image类型
a=np.zeros((256,256,3))
b = Image.fromarray(a, mode='RGB')
b.show()
- mode
数字图像处理中,针对不同的图像格式有其特定的处理算法。所以,在做图像处理之前,我们需要考虑清楚自己要基于哪种格式的图像进行算法设计及其实现。
(1) 灰度图模式为“L”,shape=[256,256,1],channel=1 ,即 "mode=RGB"表示3通道
对于灰度图像,不管其图像格式是PNG,还是BMP,或者JPG,打开后,其模式为“L”。
(2)彩色图模式为RGB,shape=[256,256,3],channel=3,即"mode=L"表示1通道
对于彩色图像,不管其图像格式是PNG,还是BMP,或者JPG,在PIL中,使用Image模块的open()函数打开后,返回的图像对象的模式都是“RGB”
(3)RGBA (4x8-bit pixels, true color with transparency mask)
(4)模式之间的相互转换–Image模块的convert()函数
-
img=img.convert(‘mode’)
-
需要注意数据的shape与mode的一一对应关系
-
相关链接:https://blog.csdn.net/zxyhhjs2017/article/details/80924210
-
读取到的numpy的ndarray结构的数据不需要显式转换成pytorch的Tensor,后续的DataLoader会自动替你转换。