pytorch中自建数据集dataset的书写方式以及读取方式

from PIL import Image
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset,DataLoader
import numpy as np
from torchvision import transform
import cv2
import os
class TorchDataSet(Dataset):
    """Image-classification dataset described by a plain-text index file.

    Each non-blank line of the index file has the form ``image_name.jpg label_id``
    (whitespace-separated). Images are read from ``image_dir`` with OpenCV,
    converted to RGB PIL images, and optionally passed through ``transform``.
    """

    # Length reported by __len__ when ``repeat`` is None ("loop forever").
    INFINITE_LEN = 10000000

    def __init__(self, filename, image_dir, repeat=1, transform=None):
        """
        :param filename: path to the index txt file; each line is
            ``image_name.jpg label_id``.
        :param image_dir: image directory; ``os.path.join(image_dir, image_name)``
            gives the full path of each image.
        :param repeat: how many times the samples are repeated (default 1).
            ``None`` means loop (effectively) forever.
        :param transform: optional callable applied to each loaded PIL image.
        """
        self.image_label_list = self.read_file(filename)
        self.image_dir = image_dir
        self.len = len(self.image_label_list)
        self.repeat = repeat
        self.transform = transform

    def __getitem__(self, i):
        """Return ``(image_name, image, label)`` for sample ``i``.

        ``i`` wraps around the underlying sample list, so every index below
        ``len(self)`` is valid when ``repeat`` > 1 or is None.

        :raises IndexError: if the index file listed no samples.
        :raises FileNotFoundError: if the image could not be read.
        """
        if self.len == 0:
            raise IndexError("dataset is empty")
        index = i % self.len
        image_name, label = self.image_label_list[index]
        image_path = os.path.join(self.image_dir, image_name)
        img = self.load_data(image_path)
        if img is None:
            # Fail loudly here instead of handing None to the transform or
            # to the DataLoader's collate function.
            raise FileNotFoundError(image_path)
        if self.transform is not None:
            img = self.transform(img)
        label = np.array(label)
        return image_name, img, label

    def __len__(self):
        """Number of samples: ``len(samples) * repeat``.

        Returns ``INFINITE_LEN`` when ``repeat`` is None, and 0 when the index
        file listed no samples.
        """
        if not self.image_label_list:
            return 0
        if self.repeat is None:
            return self.INFINITE_LEN
        return len(self.image_label_list) * self.repeat

    def read_file(self, filename):
        """Parse the index file into a list of ``(image_name, int_label)`` tuples.

        Blank lines are skipped.
        """
        image_label_list = []
        with open(filename, "r") as f:
            for line in f:
                # split() with no argument splits on any run of whitespace
                # (spaces, tabs) and discards the trailing "\n"/"\r".
                content = line.split()
                if not content:
                    continue  # blank line, e.g. a trailing empty line
                name = content[0]
                # The label is read as str from the text file; convert to int.
                label = int(content[1])
                image_label_list.append((name, label))
        return image_label_list

    def load_data(self, filename):
        """Read an image with OpenCV and return it as an RGB PIL image.

        Prints a warning and returns None if ``cv2.imread`` cannot read the file.
        """
        bgr_image = cv2.imread(filename)
        if bgr_image is None:
            print("Warning:不存在:{}".format(filename))
            return None
        if bgr_image.ndim == 2:
            # Single-channel (grayscale) array: expand it to three RGB channels.
            print("Warning: gray image", filename)
            rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2RGB)
        else:
            # OpenCV returns colour images in BGR order; PIL expects RGB.
            rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb_image)

    def data_processing(self, data):
        """Convert ``data`` (e.g. a PIL image) to a tensor with torchvision's ToTensor."""
        # Imported here because only this helper needs torchvision.
        from torchvision import transforms
        return transforms.ToTensor()(data)

完整的用法为下列步骤

if __name__ == "__main__":
    # The torchvision module is ``transforms`` (plural).
    from torchvision import transforms

    test_filename = "./test.txt"
    image_dir = "./test"
    # Compose takes the list of transforms as its first argument; passing it
    # positionally or as ``transforms=[...]`` is equivalent.
    trans = transforms.Compose([
        transforms.Resize(341),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.1, 0.1, 0.1]),
    ])

    test_data = TorchDataSet(filename=test_filename, image_dir=image_dir,
                             repeat=1, transform=trans)
    test_loader = DataLoader(dataset=test_data, batch_size=1, shuffle=False)

    # Iterate over the loader to show what each batch looks like.
    for image_names, images, labels in test_loader:
        print(image_names, images.shape, labels)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值