python构建自己的数据集_pytorch构建自己的数据集

现在需要在json文件里面读取图片的URL和label,这里面可能会出现某些URL地址无效的情况。

python读取json文件

此处只需要将json文件里面的内容读取出来就可以了

with open("json_path",'r') ad load_f:

load_dict= json.load(load_f)

json_path是json文件的地址,json文件里面的内容读取到load_dict变量中,变量类型为字典类型。

python通过URL打开图片

通过skimage获取URL图片是简单的方式。

from skimage importio

image= io.imread(img_src) #img_src是图片的URL

io.imshow(image)

io.show()

pytorch构建自己的数据集

pytorch中文网中有比较好的讲解: https://ptorch.com/news/215.html

加载图片预处理以及可视化见: https://oldpan.me/archives/pytorch-transforms-opencv-scikit-image

定义自己的数据集使用类 torch.utils.data.Dataset这个类,这个类中有三个关键的默认成员函数,__init__,__len__,__getitem__。

__init__类实例化应用,所以参数项里面最好有数据集的path,或者是数据以及标签保存的json、csv文件,在__init__函数里面对json、csv文件进行解析。

__len__需要返回images的数量。

__getitem__中要返回image和相对应的label,要注意的是此处参数有一个index,指的返回的是哪个image和label。

importtorchfrom torchvision importtransformsimportjsonimportosfrom PIL importImageclassProductDataset(torch.utils.data.Dataset):def __init__(self,json_path,data_path,transform = None,train =True):

with open(json_path,'r') as load_f:

self.json_dict=json.load(load_f)

self.json_dict= self.json_dict["images"]

self.train=train

self.data_path=data_path

self.transform=transformdef __len__(self):returnlen(self.json_dict)def __getitem__(self,index):

image_id= os.path.join(self.data_path + '/',str(self.json_dict[index]["id"]))

image=Image.open(image_id)

image= image.convert('RGB')

label= int(self.json_dict[index]["class"])ifself.transform:

image=self.transform(image)ifself.train:returnimage,labelelse:

image_id= self.json_dict[index]["id"]returnimage,label,image_idif __name__ == '__main__':

val_dataset= ProductDataset('data/FullImageTrain.json','data/train',train=False,

transform=transforms.Compose([

transforms.Pad(4),

transforms.RandomResizedCrop(224),

transforms.RandomHorizontalFlip(),

transforms.ToTensor(),

transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

]))

kwargs= {'num_workers': 4, 'pin_memory': True}

test_loader= torch.utils.data.DataLoader(dataset=val_dataset,

batch_size=32,

shuffle=False,**kwargs)print(val_dataset.__len__())

count=0for image,label,image_id intest_loader:print(image.shape,count)

count+= 1

1546967-20190529190418741-1850015310.png

关于transform,图像预处理的各个函数功能介绍如下:

torch.transforms是常见的图像变换,可以用Compose连接起来。

下面是Transforms on PIL Image:

transforms.CenterCrop(size):

size可以是一个像(h,w)的sequence,这样输出的是一个中心裁剪的(h,w)图像。

transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):

随机更改图像的亮度,对比度和饱和度。

传递的参数是float型变量或者是tuple(元素是float型)型变量,如果是tuple型变量,第一个元素是min值,第二个元素是max值。

transforms.Grayscale(num_output_channels=1)

将Image转换为灰度值

transforms.Pad(padding, fill=0, padding_mode='constant')

padding这个参数,如果给定的是单个的值,那么会pad所有的边。

transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')

随机裁剪图片到给定尺寸

size如果是(h,w)这样的sequence,那么将剪出一个(h,w)大小的图片

transforms.RandomHorizontalFlip(p=0.5):

以给定的概率随机水平翻转给定的PIL图像。

transforms.RandomResizedCrop(size,scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)

将给定的图像随机裁剪为不同的大小和高宽比,然后缩放所裁剪的图像到指定大小。

该操作的含义:即使只是该物体的一部分,我们也认为这是该类物体。

scale为0.08到1的意思为裁剪的面积比例为0.08到1,注意是面积不是边,ratio是高宽比。

transforms.Resize(size, interpolation=2):

Resize给定的Image图像到指定大小。

size:给定图像大小

interpolation:差值方法,默认是PIL.Image.BILINEAR

下面是Transforms on torch.*Tensor:

transforms.Normalize(mean,var,inplace=False):

标准化图像,mean和var给定三个值的情况下,是分别对于RGB三个channel进行标准化。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
构建PyTorch数据集一般要以下步骤: 1. 定义数据集类:自定义数据集类,继承PyTorch的Dataset类,并实现__len__和__getitem__两个方法,分别用于获取数据集大小和获取数据集中的单个样本。 2. 数据预处理:根据求对数据进行预处理,如图像数据要进行归一化、裁剪、缩放等操作。 3. 数据增强:为了增加数据集的多样性,可以对数据进行旋转、平移、翻转等变换操作。 4. 数据加载器:使用PyTorch提供的DataLoader类,可以将数据集加载到内存中,实现批量处理和多线程加速。 下面是一个简单的示例代码,用于构建一个人脸数据集: ```python import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms class FaceDataset(Dataset): def __init__(self, data_path): self.data = torch.load(data_path) self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) def __len__(self): return len(self.data) def __getitem__(self, idx): img = self.data[idx]['img'] label = self.data[idx]['label'] img = self.transform(img) return img, label if __name__ == '__main__': dataset = FaceDataset('face_data.pt') dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4) ``` 以上代码中,FaceDataset类继承了Dataset类,并实现了__len__和__getitem__方法。在初始化方法中,我们加载了数据集,并定义了数据预处理操作。在__getitem__方法中,我们返回了一个样本和其对应的标签。最后,我们使用DataLoader类将数据集加载到内存中,并定义了批量大小、是否打乱、以及多线程数量。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值