pytorch构建自己的数据集

现在需要在json文件里面读取图片的URL和label,这里面可能会出现某些URL地址无效的情况。

python读取json文件

此处只需要将json文件里面的内容读取出来就可以了

with open("json_path",'r') ad load_f:
    load_dict = json.load(load_f)

json_path是json文件的地址,json文件里面的内容读取到load_dict变量中,变量类型为字典类型。

python通过URL打开图片

通过skimage获取URL图片是简单的方式。

from skimage import io
image = io.imread(img_src) # img_src是图片的URL
io.imshow(image)
io.show()

pytorch构建自己的数据集

pytorch中文网中有比较好的讲解: https://ptorch.com/news/215.html

加载图片预处理以及可视化见: https://oldpan.me/archives/pytorch-transforms-opencv-scikit-image

定义自己的数据集使用类 torch.utils.data.Dataset这个类,这个类中有三个关键的默认成员函数,__init__,__len__,__getitem__。

__init__类实例化应用,所以参数项里面最好有数据集的path,或者是数据以及标签保存的json、csv文件,在__init__函数里面对json、csv文件进行解析。

__len__需要返回images的数量。

__getitem__中要返回image和相对应的label,要注意的是此处参数有一个index,指的返回的是哪个image和label。

 

import torch
from torchvision import transforms 
import json
import os
from PIL import Image


class ProductDataset(torch.utils.data.Dataset):
    def __init__(self,json_path,data_path,transform = None,train = True):
        with open(json_path,'r') as load_f:
            self.json_dict = json.load(load_f)
        self.json_dict = self.json_dict["images"]
        self.train = train
        self.data_path = data_path
        self.transform = transform

    def __len__(self):
        return len(self.json_dict)

    def __getitem__(self,index):
        image_id = os.path.join(self.data_path + '/',str(self.json_dict[index]["id"]))
        image = Image.open(image_id)
        image = image.convert('RGB')
        label = int(self.json_dict[index]["class"])
        if self.transform:
            image = self.transform(image)
        if self.train:
            return image,label
        else:
            image_id = self.json_dict[index]["id"]
            return image,label,image_id


if __name__ == '__main__':
    val_dataset = ProductDataset('data/FullImageTrain.json','data/train',train=False,
                                transform=transforms.Compose([
                                    transforms.Pad(4),
                                    transforms.RandomResizedCrop(224),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                                ]))
    kwargs = {'num_workers': 4, 'pin_memory': True}
    test_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                batch_size=32,
                                                shuffle=False,
                                                **kwargs)

    print(val_dataset.__len__())
    count = 0
    for image,label,image_id in test_loader:
        print(image.shape,count)
        count += 1

 

关于transform,图像预处理的各个函数功能介绍如下:

torch.transforms是常见的图像变换,可以用Compose连接起来。

下面是Transforms on PIL Image:

transforms.CenterCrop(size):

size可以是一个像(h,w)的sequence,这样输出的是一个中心裁剪的(h,w)图像。

transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):

随机更改图像的亮度,对比度和饱和度。

传递的参数是float型变量或者是tuple(元素是float型)型变量,如果是tuple型变量,第一个元素是min值,第二个元素是max值。

transforms.Grayscale(num_output_channels=1)

将Image转换为灰度值

transforms.Pad(padding, fill=0, padding_mode='constant')

padding这个参数,如果给定的是单个的值,那么会pad所有的边。

transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')

随机裁剪图片到给定尺寸

size如果是(h,w)这样的sequence,那么将剪出一个(h,w)大小的图片

transforms.RandomHorizontalFlip(p=0.5):

以给定的概率随机水平翻转给定的PIL图像。

transforms.RandomResizedCrop(size,scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)

将给定的图像随机裁剪为不同的大小和高宽比,然后缩放所裁剪的图像到指定大小。

 

该操作的含义:即使只是该物体的一部分,我们也认为这是该类物体。

scale为0.08到1的意思为裁剪的面积比例为0.08到1,注意是面积不是边,ratio是高宽比。
transforms.Resize(size, interpolation=2):

Resize给定的Image图像到指定大小。

size:给定图像大小

interpolation:差值方法,默认是PIL.Image.BILINEAR

下面是Transforms on torch.*Tensor:

transforms.Normalize(mean,var,inplace=False):

标准化图像,mean和var给定三个值的情况下,是分别对于RGB三个channel进行标准化。

 

转载于:https://www.cnblogs.com/yanxingang/p/10658124.html

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: PyTorch是一个流行的深度学习框架,它提供了许多工具和功能来处理各种类型的数据集。其中一种常见的数据格式是JSON(JavaScript Object Notation)。 JSON是一种轻量级的数据交换格式,它使用类似于字典的结构来表示数据。在PyTorch中,可以使用内置的json模块来读取和处理JSON数据集。 首先,我们需要使用Pythonjson库将JSON数据加载到内存中。可以使用`json.load()`函数来读取JSON文件,返回一个包含JSON数据的Python字典。例如,如果我们的JSON文件名为"dataset.json",可以使用以下代码加载数据集: ``` import json with open('dataset.json', 'r') as f: dataset = json.load(f) ``` 然后,我们可以根据数据集的结构使用PyTorch的功能来进一步处理数据。例如,如果我们有一个包含图像路径和标签的JSON数据集,可以使用PyTorch的`torchvision`模块来加载图像和标签,并进行预处理和转换: ``` from torchvision import transforms from PIL import Image transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), ]) for data in dataset: image_path = data['image_path'] label = data['label'] image = Image.open(image_path) image = transform(image) # 进一步处理图像和标签 ``` 最后,我们可以使用PyTorch的DataLoader来创建一个可迭代的数据加载器,并在训练模型时使用该数据加载器。这个数据加载器可以提供按批次加载数据、数据随机排序等功能,从而方便地处理大规模的JSON数据集。 总而言之,PyTorch提供了丰富的功能来处理JSON数据集。我们可以使用json库加载和解析JSON数据,并使用PyTorch的功能来进一步处理和转换数据,以进行深度学习模型的训练和评估。 ### 回答2: PyTorch是一种流行的深度学习框架,它提供了各种各样的功能和工具来帮助我们构建和训练机器学习模型。而JSON(JavaScript Object Notation)是一种轻量级的数据交换格式,常用于在不同操作系统和编程语言之间传输和存储数据。在PyTorch中,我们可以使用JSON数据集来加载和预处理我们的训练和测试数据。 PyTorch提供了一个名为torchvision的库,其中包含了一些内置的数据集,例如MNIST、CIFAR-10等。然而,如果我们想使用自定义的JSON数据集,我们需要首先将数据转换为PyTorch所需的格式。 为了加载JSON数据集,我们可以使用PyTorchDataset类和DataLoader类。Dataset类是一个抽象类,我们需要继承它并实现__len__和__getitem__两个方法来定义我们自己的数据集类。在__getitem__方法中,我们可以使用Pythonjson库来读取和解析JSON文件,并将数据进行预处理。 在数据预处理阶段,我们可以使用PyTorch中的transforms模块来进行常见的图像处理操作,例如裁剪、缩放、旋转等。这些处理操作可以应用于加载的JSON数据,以使其适用于我们的模型。另外,我们还可以通过实现自定义的数据转换函数来进行更复杂的数据处理操作。 一旦我们完成了数据集的加载和预处理,我们可以使用DataLoader类来生成一个可迭代的数据加载器。该加载器将自动处理数据的批处理、随机排序和多线程加载等细节,方便我们在训练过程中高效地加载和处理数据。 总而言之,PyTorch提供了丰富的功能和工具来处理JSON数据集。我们可以使用Dataset和DataLoader类来加载和预处理数据,并利用transforms模块和自定义函数来进行数据转换。通过这些操作,我们可以轻松地准备我们的数据,并将其用于PyTorch模型的训练和评估过程中。 ### 回答3: PyTorch是一个开源的深度学习框架,用于构建和训练神经网络模型。在PyTorch中,我们可以使用JSON数据集来加载和处理数据。 JSON(JavaScript Object Notation)是一种常用的轻量级数据交换格式,它以易于阅读和编写的方式表示结构化数据。JSON数据集通常由一个JSON文件组成,其中包含一个或多个样本的信息。 首先,我们需要使用Pythonjson库来读取和解析JSON文件。通过使用`json.load()`方法,我们可以将JSON文件中的内容加载为Python字典类型的对象。然后,我们可以使用Python的字典操作来访问和处理数据。 在PyTorch中,我们通常将数据加载到自定义的Dataset类中。我们可以创建一个继承自torch.utils.data.Dataset的子类,并实现其中的`__getitem__()`和`__len__()`方法来处理数据。在`__getitem__()`方法中,我们可以通过索引来访问和提取具体的样本数据,并对其进行预处理。 一种常见的做法是将JSON文件中的每个样本表示为一个字典,在字典中存储不同的特征和对应的标签。我们可以在`__getitem__()`方法中使用`json_data[索引]`来访问特定索引处的样本,然后从字典中提取所需的数据。 另一种常见的做法是将JSON文件中的每个样本表示为一个字符串对象,其中包含所有样本的信息。我们可以使用`json.loads()`方法将字符串转换为JSON对象,并进一步提取所需的特征和标签信息。 最后,我们可以使用PyTorch的DataLoader类来批量加载和处理JSON数据集。通过设置参数如batch_size、shuffle等,我们可以灵活地配置数据加载和处理的方式。 总的来说,PyTorch提供了灵活而强大的工具,可以处理和加载JSON数据集。我们可以根据具体的需求,使用不同的方法来读取JSON文件并提取所需的数据,然后在自定义的Dataset类中进行进一步的处理和批量加载,以便用于训练和评估神经网络模型。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值