在用 PyTorch 实现神经网络时需要读取数据,我们经常继承 torch.utils.data.dataset.Dataset 来自定义数据集;而在具体读取图像时可以使用 OpenCV 或 PIL 两种方法,它们在具体的数据处理方式上是不一样的。
1、使用opencv:
# -*- coding: utf-8 -*-
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
import os
import cv2
from PIL import Image
class MyDataset(Dataset):
    """Image dataset that loads files with OpenCV and yields CHW RGB tensors.

    Parameters
    ----------
    transform : callable or None
        Optional transform to store on the dataset. Defaults to a minimal
        ``ToTensor`` pipeline. NOTE(review): this code path converts the
        image manually in ``__getitem__`` and does not apply the stored
        transform, matching the original cv2-based example.
    """

    def __init__(self, transform=None):
        # Bug fix: the original unconditionally overwrote the caller-supplied
        # transform with the hard-coded default.
        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor()  # minimal default pipeline
            ])
        self.image_path = './image_data2/'
        self.image_names = os.listdir(self.image_path)

    def __len__(self):
        # Number of files found in the image directory.
        return len(self.image_names)

    def __getitem__(self, item):
        image_name = self.image_names[item]
        full_path = os.path.join(self.image_path, image_name)
        # cv2.imread returns an H x W x C uint8 array in BGR channel order,
        # or None when the file cannot be decoded — fail loudly instead of
        # crashing with a cryptic cvtColor error.
        image = cv2.imread(full_path)
        if image is None:
            raise FileNotFoundError(f'cannot read image: {full_path}')
        # BGR -> RGB (equivalently: image = image[:, :, (2, 1, 0)]).
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Convert H,W,C -> C,H,W as PyTorch expects. dtype stays uint8 and
        # values stay in [0, 255]; unlike ToTensor, no scaling to [0, 1].
        image = torch.from_numpy(image).permute(2, 0, 1)
        return image
2、使用PIL:
# -*- coding: utf-8 -*-
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
import os
import cv2
from PIL import Image
class MyDataset(Dataset):
    """Image dataset that loads files with PIL and applies a torchvision transform.

    Parameters
    ----------
    transform : callable or None
        Transform applied to each loaded PIL image. Defaults to a minimal
        ``ToTensor`` pipeline (H,W,C uint8 -> C,H,W float in [0, 1]).
    """

    def __init__(self, transform=None):
        # Bug fix: the original unconditionally overwrote the caller-supplied
        # transform with the hard-coded default.
        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor()  # minimal default pipeline
            ])
        self.image_path = './image_data2/'
        self.image_names = os.listdir(self.image_path)

    def __len__(self):
        # Number of files found in the image directory.
        return len(self.image_names)

    def __getitem__(self, item):
        image_name = self.image_names[item]
        # PIL opens the file in its native mode (L, RGB, RGBA, ...); convert
        # explicitly to RGB so grayscale/alpha images still produce
        # 3-channel tensors that can be batched together.
        image = Image.open(os.path.join(self.image_path, image_name)).convert('RGB')
        # The PIL image is conceptually H,W,C; ToTensor (in the default
        # transform) performs the H,W,C -> C,H,W conversion and scales
        # uint8 values into [0, 1].
        image = self.transform(image)
        return image