PyTorch—ImageFolder/自定义类 读取图片数据


运行环境安装 Anaconda | python ==3.6.6

conda install pytorch -c pytorch
pip install config
pip install tqdm             #包装迭代器,显示进度条
pip install torchvision
pip install scikit-image

一、torchvision 图像数据读取 [0, 1]

import torchvision.transforms as transforms
transforms 模块提供了一般的图像转换操作类。
class torchvision.transforms.ToTensor
功能:
把shape=(H x W x C) 的像素值为 [0, 255] 的 PIL.Image 和 numpy.ndarray
转换成shape=(C x H x W)的像素值范围为[0.0, 1.0]的 torch.FloatTensor。

class torchvision.transforms.Normalize(mean, std)
功能:
此转换类作用于torch.*Tensor。给定均值(R, G, B)和标准差(R, G, B),用公式channel = (channel - mean) / std进行规范化。

import torchvision 
import torchvision.transforms as transforms 
import cv2 
import numpy as np 
from PIL import Image 

img_path = "./data/timg.jpg" 

# 引入transforms.ToTensor()功能: range [0, 255] -> [0.0,1.0] 
transform1 = transforms.Compose([transforms.ToTensor()])

# 直接读取:numpy.ndarray 
img = cv2.imread(img_path)
print("img = ", img[0])      #只输出其中一个通道
print("img.shape = ", img.shape)

# 归一化,转化为numpy.ndarray并显示
img1 = transform1(img) 
img2 = img1.numpy()*255 
img2 = img2.astype('uint8') 
img2 = np.transpose(img2 , (1,2,0)) 
 
print("img1 = ", img1)
cv2.imshow('img2 ', img2 ) 
cv2.waitKey() 


# PIL 读取图像
img = Image.open(img_path).convert('RGB') # 读取图像 
img2 = transform1(img) # 归一化到 [0.0,1.0] 
print("img2 = ",img2) #转化为PILImage并显示 
img_2 = transforms.ToPILImage()(img2).convert('RGB') 
print("img_2 = ",img_2) 
img_2.show()


从上到下依次输出:---------------------------------------------
img =   [[197 203 202]
	 [195 203 202]
	 ...
	 [200 208 207]
	 [200 208 207]]
img.shape =  (362, 434, 3)

img1 =  tensor([[[0.7725, 0.7647, 0.7686,  ..., 0.7804, 0.7843, 0.7843],
         [0.7765, 0.7725, 0.7686,  ..., 0.7686, 0.7608, 0.7569],
         [0.7843, 0.7725, 0.7686,  ..., 0.7725, 0.7686, 0.7569],
         ...,

img_transform =  tensor([[[0.7922, 0.7922, 0.7961,  ..., 0.8078, 0.8118, 0.8118],
         [0.7961, 0.8000, 0.7961,  ..., 0.7922, 0.7882, 0.7843],
         [0.8039, 0.8000, 0.7961,  ..., 0.8118, 0.8039, 0.7922],
         ...,

在这里插入图片描述
transforms.Compose 归一化到 [-1.0, 1.0 ]

transform2 = transforms.Compose([transforms.ToTensor()])
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))]) 

二、torchvision 的 Transform

在深度学习时关于图像的数据读取:由于Tensorflow不支持与numpy的无缝切换,导致难以使用现成的pandas等格式化数据读取工具,造成了很多不必要的麻烦,而pytorch解决了这个问题。

pytorch自定义读取数据和进行Transform的部分请见文档:
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

但是按照文档中所描述所完成的自定义Dataset只能够使用自定义的Transform步骤,而torchvision包中已经给我们提供了很多图像transform步骤的实现,为了使用这些已经实现的Transform步骤,我们可以使用如下方法定义Dataset:

from __future__ import print_function, division 
import os 
import torch 
import pandas as pd 
from PIL import Image 
import numpy as np 
from torch.utils.data import Dataset, DataLoader 
from torchvision import transforms 

class FaceLandmarkDataset(Dataset): 
    def __len__(self) -> int: 
        return len(self.landmarks_frame)
		
    def __init__(self, csv_file: str, root_dir: str, transform=None) -> None: 
        super().__init__() 
        self.landmarks_frame = pd.read_csv(csv_file) 
        self.root_dir = root_dir 
        self.transform = transform 

    def __getitem__(self, index:int): 
        img_name = self.landmarks_frame.ix[index, 0] 
        img_path = os.path.join('./faces', img_name) 
        with Image.open(img_path) as img: 
            image = img.convert('RGB') 
        landmarks = self.landmarks_frame.as_matrix()[index, 1:].astype('float') 
        landmarks = np.reshape(landmarks,newshape=(-1,2)) 
        if self.transform is not None: 
            image = self.transform(image) 
        return image, landmarks 

########################以上为数据读取类(返回:image,landmarks)###############################
trans = transforms.Compose(transforms = [transforms.RandomSizedCrop(size=128), 
                                         transforms.ToTensor()]) 

face_dataset = FaceLandmarkDataset(csv_file='faces/face_landmarks.csv', 
				   root_dir='faces', transform= trans) 
loader = DataLoader(dataset = face_dataset, 
                    batch_size=4,
		    shuffle=True,
		    num_workers=4)

三、读取图像数据类

3.1 class torchvision.datasets.ImageFolder 默认读取图像数据方法:
  • __init__( 初始化)
    • classes, class_to_idx = find_classes(root) :得到分类的类别名(classes)和类别名与数字类别的映射关系字典(class_to_idx)
      其中 classes (list): List of the class names.
      其中 class_to_idx (dict): Dict with items (class_name, class_index).
    • imgs = make_dataset(root, class_to_idx)得到imags列表。
      其中 imgs (list): List of (image path, class_index) tuples
      每个值是一个tuple,每个tuple包含两个元素:图像路径和标签
  • __getitem__(图像获取)
    • path, target = self.imgs[index] 获取图像(路径,标签)
    • img = self.loader(path)数据读取。
    • img = self.transform(img)数据、标签 转换成 tensor
    • target = self.target_transform(target)
  • __len__( 数据集数量)
    • return len(self.imgs)
class ImageFolder(data.Dataset):
    """默认图像数据目录结构
    root
    .
    ├──dog
    |   ├──001.png
    |   ├──002.png
    |   └──...
    └──cat  
    |   ├──001.png
    |   ├──002.png
    |   └──...
    └──...
    """
    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        index (int): Index
	Returns:tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.imgs)

图像获取 __getitem__ 中,self.loader(path) 采用的是default_loader,如下

def pil_loader(path):    # 一般采用pil_loader函数。
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)
3.2 自定义数据读取方法

PyTorch中和数据读取相关的类都要继承一个基类:torch.utils.data.Dataset。
故需要改写其中的 __init__、__len__、__getitem__ 等三个方法即可。

  • __init__()初始化传入参数:
    • img_path 里面为所有图像数据(包括训练和测试)
      txt_path 里面有 train.txt和val.txt两个文件:txt文件中每行都是图像路径,tab键,标签。
    • 其中 self.img_name 和 self.img_label 的读取方式就跟你数据的存放方式有关(需要调整的地方)
  • __getitem__()依然采用default_loader方法来读取图像。
  • Transform中将每张图像都封装成 Tensor
class customData(Dataset):
    def __init__(self, img_path, txt_path, dataset = '',data_transforms=None, loader = default_loader):
        with open(txt_path) as input_file:
            """
	    关于json文件解析:
	    https://blog.csdn.net/wsp_1138886114/article/details/83302339
	    txt文件解析如下,具体文本解析具体分析,没有定数
            """
            lines = input_file.readlines()
            self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
            self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)

        if self.data_transforms is not None:
            try:
                img = self.data_transforms[self.dataset](img)
            except:
                print("Cannot transform image: {}".format(img_name))
        return img, label
#####################以上为图像数据读取,返回(img, label)#########################

# 保证image_datasets与torchvision.datasets.ImageFolder类返回的数据类型一样
image_datasets = {x: customData(img_path='/ImagePath',
                                txt_path=('/TxtFile/' + x + '.txt'),
                                data_transforms=data_transforms,
                                dataset=x) for x in ['train', 'val']}

#用torch.utils.data.DataLoader类,将这个batch的图像数据和标签都分别封装成Tensor。
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                             batch_size=batch_size,
                                             shuffle=True) for x in ['train', 'val']}

# 模型保存
torch.save(model, 'output/resnet_epoch{}.pkl'.format(epoch))

https://pytorch-cn.readthedocs.io/zh/latest/package_references/data/#torchutilsdata
鸣谢
https://www.cnblogs.com/denny402/p/5096001.html
https://blog.csdn.net/VictoriaW/article/details/72822005
https://blog.csdn.net/hao5335156/article/details/80593349
https://blog.csdn.net/u014380165/article/details/78634829

  • 53
    点赞
  • 220
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
### 回答1: 在 PyTorch 中读取自定义数据集的一般步骤如下: 1. 定义数据:首先需要定义一个数据,继承自 `torch.utils.data.Dataset` ,并实现 `__getitem__` 和 `__len__` 方法。在 `__getitem__` 方法中,根据索引返回一个样本的数据和标签。 2. 加载数据集:使用 `torch.utils.data.DataLoader` 加载数据集,可以设置批量大小、多线程读取数据等参数。 下面是一个简单的示例代码,演示如何使用 PyTorch 读取自定义数据集: ```python import torch from torch.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, data, targets): self.data = data self.targets = targets def __getitem__(self, index): x = self.data[index] y = self.targets[index] return x, y def __len__(self): return len(self.data) # 加载训练集和测试集 train_data = ... train_targets = ... train_dataset = CustomDataset(train_data, train_targets) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_data = ... test_targets = ... test_dataset = CustomDataset(test_data, test_targets) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # 训练模型 for epoch in range(num_epochs): for batch_idx, (data, targets) in enumerate(train_loader): # 前向传播、反向传播,更新参数 ... ``` 在上面的示例代码中,我们定义了一个 `CustomDataset` ,加载了训练集和测试集,并使用 `DataLoader` 分别对它们进行批量读取。在训练模型时,我们可以像使用 PyTorch 自带的数据集一样,循环遍历每个批次的数据和标签,进行前向传播、反向传播等操作。 ### 回答2: PyTorch是一个开源的深度学习框架,它提供了丰富的功能用于读取和处理自定义数据集。下面是一个简单的步骤来读取自定义数据集。 首先,我们需要定义一个自定义数据,该应继承自`torch.utils.data.Dataset`,并实现`__len__`和`__getitem__`方法。`__len__`方法应返回数据集的样本数量,`__getitem__`方法根据给定索引返回一个样本。 ```python import torch from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] return torch.tensor(sample) ``` 接下来,我们可以创建一个数据集实例并传入自定义数据。假设我们有一个包含多个样本的列表 `data`。 ```python data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] dataset = CustomDataset(data) ``` 然后,我们可以使用`torch.utils.data.DataLoader`加载数据集,并指定批次大小、是否打乱数据等。 ```python batch_size = 2 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) ``` 现在,我们可以迭代数据加载器来获取批次的样本。 ```python for batch in dataloader: print(batch) ``` 上面的代码将打印出两个批次的样本。如果`shuffle`参数设置为`True`,则每个批次的样本将是随机的。 总而言之,PyTorch提供了简单而强大的工具来读取和处理自定义数据集,可以根据实际情况进行适当修改和扩展。 ### 回答3: PyTorch是一个流行的深度学习框架,可以用来训练神经网络模型。要使用PyTorch读取自定义数据集,可以按照以下几个步骤进行: 1. 准备数据集:将自定义数据集组织成合适的目录结构。通常情况下,可以将数据集分为训练集、验证集和测试集,每个集合分别放在不同的文件夹中。确保每个文件夹中的数据按照别进行分,以便后续的标签处理。 2. 创建数据加载器:在PyTorch中,数据加载器是一个有助于有效读取和处理数据。可以使用`torchvision.datasets.ImageFolder`创建一个数据加载器对象,通过传入数据集的目录路径来实现。 3. 数据预处理:在将数据传入模型之前,可能需要对数据进行一些预处理操作,例如图像变换、标准化或归一化等。可以使用`torchvision.transforms`中的来实现这些预处理操作,然后将它们传入数据加载器中。 4. 创建数据迭代器:数据迭代器是连接数据集和模型的重要接口,它提供了一个逐批次加载数据的功能。可以使用`torch.utils.data.DataLoader`创建数据迭代器对象,并设置一些参数,例如批量大小、是否打乱数据等。 5. 使用数据迭代器:在训练时,可以使用Python的迭代器来遍历数据集并加载数据。通常,它会在每个迭代步骤中返回一个批次的数据和标签。可以通过`for`循环来遍历数据迭代器,并在每个步骤中处理批次数据和标签。 这样,我们就可以在PyTorch中成功读取并处理自定义数据集。通过这种方式,我们可以更好地利用PyTorch的功能来训练和评估自己的深度学习模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

SongpingWang

你的鼓励是我创作的最大动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值