Pytorch读取图片

    由Tensorflow转Pytorch,慢慢开始吧
    Pytorch的torchvision虽然包含了几个数据集的读取API,例如CAFRI10,Imagenet,MNIST等,但这远远不够。实际应用中,我们需要从各种不同的数据集中读取图片。
    Pytorch自定义读取数据的方式,主要用到两个类:
torch.utils.data.Datasettorch.utils.data.DataLoader
    为了自由读取数据集中的数据(图片),必须写一个Dataset的子类,该子类中必须overrider两个Dataset中的方法:__getitem__(self, index)__len__(self)
    前者是一个通过index索引来读取数据的方法,在这个方法中,可以利用torchvision.transforms来对数据做一些预处理。需要注意的是,该方法中产生的数据就是直接传递到DataLoader中的数据,因此,数据格式一定要是Tensor!
后者返回数据集的长度。
    最后调用DataLoader来形成batch,并可以进行shuffle等操作。
具体代码如下:

import cv2
import os
import numpy as np
import torch 
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader


name = 'C:\\Users\\Administrator\\Desktop\\history\\HED-BSDS\\test1.lst'
base_path = 'C:\\Users\\Administrator\\Desktop\\history\\HED-BSDS'

#首先定义一个Dataset的子类->myDataset
class myDataset(Dataset):
    def __init__(self, name, base_path):
        f = open(name)
        self.filenames = f.readlines()
        f.close()
 #override这两个方法
    def __getitem__(self, index):
        path = self.filenames[index]
        print(os.path.join(base_path, path))
        img = cv2.imread(os.path.join(base_path, path).strip())
        img = torch.Tensor(img)
        return img

    def __len__(self):
        return len(self.filenames)

dataset = myDataset

train_loader = DataLoader(dataset(name=name, base_path=base_path), 
    batch_size=4, shuffle=True)
for img in train_loader:
    print(img.size())
    cv2.imshow('we', np.uint8(img.numpy()[0]))
    cv2.waitKey()

最初级的读取方式就是这样。
另外,torchvision.data.ImageFolder也可以方便读取已经分类好的,放在不同文件夹中的图片。

  • 5
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值