pytorch中的2种处理载入图像数据的方式

第一种:通过pytorch中内置的ImageFolder()方法

以下面数据集为例:
在这里插入图片描述

# 1.用torch自带的ImageFolder()函数制作数据源
data_dir = r'D:/Projects/Datasets/flower_data/'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
cat_to_name_file = r'D:/Projects/Datasets/flower_data/cat_to_name.json'

# 下面是'''for ResNet'''
data_transforms = {
    'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(224),#从中心开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
    ]),
    'valid': transforms.Compose([transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# 下面是'''for inception-v3'''
data_transforms1 = {
    'train': transforms.Compose([
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
}

batch_size = 8

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes

#print(image_datasets)
print(dataloaders)
print(class_names)
print(dataset_sizes)

def showImage(tensor): 
    '''
    function:展示数据(tensor转换成numpy) 
    tensor: 原tensor格式图片 
    return: numpy格式图片 
    '''
    image = tensor.to("cpu").clone().detach()

    image = image.numpy().squeeze() 
    image = image.transpose(1,2,0) 
    image = image*np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))

    image = image.clip(0, 1) 
    return image    
fig = plt.figure(figsize=(20, 12))    
columns = 4    
rows = 2    
with open(cat_to_name_file, 'r') as f:

    cat_to_name = json.load(f)     
print(cat_to_name)     

dataiter = iter(dataloaders['valid'])

inputs, classes = next(iter(dataiter))

#print('inputs:', inputs)     
print('classes:', classes)     

for idx in range (columns*rows):     
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])

    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])

    plt.imshow(showImage(inputs[idx]))

plt.show()     

第二种:通过自定义Dataloader()来处理数据

以下面数据集为例:

#2. 通过自定义dataloader 
data_dir = r'D:/Projects/Datasets/flower_photos/'
data_transform = {        
    "train"  : transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),        
                                 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),      
                               transforms.ToTensor(),      
                               transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
class FlowerDataset(Dataset):    
    def __init__(self, root_dir, ann_file, transform):
        # 数据集根目录路径和标签文件目录   
        self.root_dir = root_dir   
        self.ann_file = ann_file   
        self.img_label = self.load_annotations() # 经过这个函数操作,获得一个字典{image_path:label}

        #self.img_label = {'sunflowers': 0, 'roses': 1, 'dandelion': 2, 'daisy': 3, 'tulips': 4, 'ddd':2}
        #print(self.img_label)
        # 将image的path单独放进一个list
        self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]
        self.lable = [label for label in list(self.img_label.values())]
        self.transform = transform
    # 返回一个数据和标签
    def __getitem__(self, idx):
        image = Image.open(self.img[idx])
        label = self.lable[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label)) # label是一个list格式,先转换成numpy格式,再转换成torch格式
        return image, label

    def __len__(self):
        return len(self.img) # 通过img这个存放名字的列表来计算总共有多少个图片

    #以字典格式制作一个数据。也可以给函数起名为:read_split_data
    def load_annotations(self):
        with open(self.ann_file, encoding='utf-8') as f: #with里定义的变量没有作用域,不要被缩进代码块误导
            names_list = [x.strip().split(' ')[0] for x in f.readlines()] # 取到名字
            #将f文件的指针复位到第一行,添加aa.seek(0)这行代码到第二次迭代前就可以啦。 ,否则后面打开文件内容为空
            f.seek(0)
            labels_list = [x.strip().split('/')[0] for x in f.readlines()] # 取到标签
            # 这个label是一个字符串,把它转换成0,1,3,4,5对应。
            # 将labels_list里面重复的字符串去掉,只保留一个,然后依次编号(这个操作可以用set()方法,转换成集合,在转换成list)
            labels_num = list(set(labels_list))
            labels_type = {name: labels_num.index(name) for name in
                           labels_num}  # 将每个类添加数字类,方便做标签分类        # {'sunflowers': 0, 'roses': 1, 'dandelion': 2, 'daisy': 3, 'tulips': 4}
            data_infos = {k: v for k, v in zip(names_list,
                                               labels_list)}  # {'daisy/7568630428_8cf0fc16ff_n.jpg': 'sunflowers', 'daisy/7410356270_9dff4d0e2e_n.jpg': 'roses',
            data_infos = {k: labels_type[v] for k, v in data_infos.items()}  # 将标签转换成对应数字
        return data_infos

ann_file='D:/Projects/Datasets/flower_photos/annotations.txt'
train_dataset = FlowerDataset(root_dir='D:/Projects/Datasets/flower_photos', ann_file=ann_file, transform=data_transform["train"])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

#测试是否打包好了
image, label = next(iter(train_loader))
sample = image[1].squeeze() # 把第一个维度去掉(1*3*224*224)变成(3*244*224)
sample = sample.permute((1,2,0)).numpy() # 变成numpy展示图像(224*224*3)
plt.imshow(sample)
plt.show()
print('lable is {}'.format(label[0].numpy()))
print(image.shape)
  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch载入图像有多方法。一常用的方法是使用`ImageFolder`类和`DataLoader`类。首先,需要导入相关的包和设置转换操作。可以使用`torchvision.datasets.ImageFolder`导入`ImageFolder`类,然后使用`torchvision.transforms`导入`transforms`模块。接下来,可以使用`transforms.Compose`创建一个转换操作的组合,其包括将图片缩放至256x256像素、从心裁剪成224x224像素、以及将图片转换为Tensor类型并进行归一化的操作。\[1\] 然后,可以使用`ImageFolder`类加载数据集。将包含分类图片的父目录路径传递给`ImageFolder`类,并传入之前创建的转换操作。这样就可以得到要加载的数据集。接下来,可以使用`DataLoader`类加载数据集,并设置批量大小和是否打乱数据的参数。\[2\] 以下是一个示例代码,展示了如何使用`DataLoader`加载数据集并显示其的图片: ```python import torch from torchvision import transforms, datasets from torch.utils.data import DataLoader # 设置转换操作 transforms = transforms.Compose(\[ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor() \]) # 加载数据集 path = r'D:\dataset_deep_learning\Flower_Orig_dataset' data_train = datasets.ImageFolder(path, transform=transforms) # 使用DataLoader加载数据 data_loader = DataLoader(data_train, batch_size=64, shuffle=True) # 显示数据集的图片 import matplotlib.pyplot as plt import numpy as np import torchvision for i, data in enumerate(data_loader): images, labels = data img = torchvision.utils.make_grid(images).numpy() plt.imshow(np.transpose(img, (1, 2, 0))) plt.show() break ``` 这段代码首先导入了必要的包,并设置了转换操作。然后使用`ImageFolder`类加载数据集,并使用`DataLoader`类加载数据。最后,使用`matplotlib.pyplot`和`numpy`库显示数据集的图片。\[2\] 希望这个例子能帮助你理解如何在PyTorch载入图像数据集。如果你需要更多细节,可以参考我在文章提供的其他资料。\[3\] #### 引用[.reference_title] - *1* *2* [pytorch加载自己的图片数据集的两方法](https://blog.csdn.net/qq_53345829/article/details/124308515)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [【深度学习】数据准备-pytorch自定义图像分割类数据集加载](https://blog.csdn.net/adreammaker/article/details/126037510)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值