import os

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class TorchDataSet(Dataset):
    """Image-classification dataset backed by a text index file.

    Each line of the index file has the form ``image_name.jpg label_id``.
    Images are read with OpenCV, converted to RGB ``PIL.Image`` objects and
    optionally passed through ``transform``.
    """

    # Length reported by ``__len__`` when ``repeat`` is None ("loop forever").
    _UNBOUNDED_LEN = 10000000

    def __init__(self, filename, image_dir, repeat=1, transform=None):
        """
        :param filename: path of the index txt file; line format:
            ``image_name.jpg label_id``
        :param image_dir: image directory; ``image_dir + image_name.jpg``
            forms the full image path
        :param repeat: how many times the whole sample list is repeated
            (default: one pass); ``None`` means loop without a practical bound
        :param transform: optional callable applied to each loaded image
        """
        self.image_label_list = self.read_file(filename)
        self.image_dir = image_dir
        self.len = len(self.image_label_list)
        self.repeat = repeat
        self.transform = transform

    def __getitem__(self, i):
        """Return ``(image_name, img, label)`` for sample ``i``.

        ``i`` wraps around the underlying list so that repeated passes map
        onto the same samples. ``img`` is ``None`` when the image file could
        not be read (see ``load_data``); the transform is skipped then.

        :raises IndexError: if the index file contained no samples
        """
        if self.len == 0:
            raise IndexError("dataset is empty")
        index = i % self.len
        image_name, label = self.image_label_list[index]
        image_path = os.path.join(self.image_dir, image_name)
        img = self.load_data(image_path)
        # Do not feed a failed load (None) into the transform pipeline.
        if img is not None and self.transform is not None:
            img = self.transform(img)
        label = np.array(label)
        return image_name, img, label

    def __len__(self):
        """Return the number of samples exposed, taking ``repeat`` into account."""
        if self.repeat is None:
            return self._UNBOUNDED_LEN
        return len(self.image_label_list) * self.repeat

    def read_file(self, filename):
        """Parse the index file into a list of ``(image_name, label)`` tuples.

        :param filename: path of the index txt file
        :return: list of ``(str, int)`` tuples, in file order
        """
        image_label_list = []
        with open(filename, "r") as f:
            for line in f:
                # split() drops surrounding whitespace (\n, \r, \t, spaces)
                # and tolerates repeated separators between the two fields.
                content = line.split()
                if not content:
                    continue  # skip blank lines
                name = content[0]
                # The label is read as str from the txt file; convert to int.
                label = int(content[1])
                image_label_list.append((name, label))
        return image_label_list

    def load_data(self, filename):
        """Read an image file and return it as an RGB ``PIL.Image``.

        :param filename: full path of the image file
        :return: RGB ``PIL.Image``, or ``None`` (after printing a warning)
            when OpenCV cannot read the file
        """
        bgr_image = cv2.imread(filename)
        if bgr_image is None:
            print("Warning:不存在:{}".format(filename))
            return None
        if len(bgr_image.shape) == 2:
            # Expand a single-channel (gray) image to three channels.
            print("Warning: gray image", filename)
            bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)
        # OpenCV loads images as BGR; PIL expects RGB channel order.
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb_image)

    def data_processing(self, data):
        """Convert ``data`` to a tensor with torchvision's ``ToTensor``."""
        return transforms.ToTensor()(data)


# Alias for the alternative capitalisation of the class name.
TorchDataset = TorchDataSet
# Complete usage example (steps below):
if __name__ == "__main__":
    # Example: build the dataset from an index file and wrap it in a DataLoader.
    test_filename = "./test.txt"
    image_dir = "./test"
    # The list is passed through the ``transforms=`` keyword; the original
    # author reported unexplained errors when the keyword was omitted.
    trans = transforms.Compose(
        transforms=[
            transforms.Resize(341),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.1, 0.1, 0.1],
            ),
        ]
    )
    test_data = TorchDataSet(
        filename=test_filename,
        image_dir=image_dir,
        repeat=1,
        transform=trans,
    )
    test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)