自定义数据集
PyTorch对数据集的加载流程做了标准化。
数据加载的具体过程:
- 继承Dataset类
PyTorch提供了抽象类torch.utils.data.Dataset,使用时需要继承这个类,并重写__len__和__getitem__方法。
- 增加数据变换
PyTorch的torchvision.transforms可以方便地对图像进行缩放、裁剪、随机旋转、填充以及张量归一化等操作,操作对象是PIL的Image或Tensor。可以用transforms.Compose将多个变换组合起来;使用时一般把它传给Dataset的子类,在__getitem__中对每个样本调用。
- 使用DataLoader
Dataset只负责按索引返回单个样本;批量读取、随机打乱(shuffle)等功能由torch.utils.data.DataLoader提供,所以还需要这一步。
代码
import argparse
import csv
import glob
import os
import time

import matplotlib.pyplot as plt
import PIL
import PIL.Image
import torch
import visdom
from torch.utils.data import Dataset
from torchvision import transforms
class MyData(Dataset):
    """Image-classification dataset read from a ``root/<class_name>/*.jpg`` tree.

    Every immediate sub-directory of ``root`` is one class. Integer labels are
    assigned in sorted order of the directory names. The (path, label) index is
    cached in ``root/Image2Label.csv`` and reused on later constructions.

    Args:
        root: Dataset root directory.
        transform: Optional callable applied to each loaded RGB PIL image
            (for example a ``transforms.Compose``).
    """

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform
        # Class-directory name -> integer label, e.g.
        # {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue  # skip plain files, such as the cached CSV index
            self.name2label[name] = len(self.name2label)
        print(self.name2label)
        # Parallel lists: image file path and its integer label.
        self.images, self.labels = self.load_csv('Image2Label.csv')

    def load_csv(self, fliename):
        """Return ``(images, labels)`` read from the CSV index ``root/fliename``.

        If the CSV does not exist yet, it is generated first by scanning every
        class directory for ``*.jpg`` files. An existing CSV is trusted as-is:
        delete it after adding or removing images to rebuild the index.

        Args:
            fliename: CSV file name, relative to ``self.root``.

        Returns:
            Two equal-length lists: image paths (str) and labels (int).
        """
        csv_path = os.path.join(self.root, fliename)
        if not os.path.exists(csv_path):
            rows = []
            for name, label in self.name2label.items():
                # glob order depends on the file system; sort so the generated
                # index is reproducible. The label is known from the loop, so
                # there is no need to re-parse it out of the path.
                for img in sorted(glob.glob(os.path.join(self.root, name, '*.jpg'))):
                    rows.append((img, label))
            print(len(rows))
            with open(csv_path, mode='w', newline='') as f:
                # Each row: 'D:\\...\\flower_photos\\val\\daisy\\00000.jpg', '0'
                csv.writer(f).writerows(rows)
            print('writen into csv file:', fliename)

        images, labels = [], []
        # newline='' is what the csv module expects for files it reads.
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:
                    continue  # tolerate blank lines, e.g. from manual edits
                img, label = row
                images.append(img)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, where ``0 <= idx < len(self)``.

        The image file is decoded to RGB, and ``self.transform`` is applied if
        it is set.
        """
        path, label = self.images[idx], self.labels[idx]
        # Open with ``with`` so the file handle is released immediately.
        # convert() returns a new, fully loaded image, so it stays valid after
        # the file is closed.
        with PIL.Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
def main():
    """Show the dataset batch-by-batch in Visdom, with the labels as text.

    Expects a Visdom server to be running (``python -m visdom.server``).
    """
    parser = argparse.ArgumentParser(description='训练参数')
    parser.add_argument('--batchsize', type=int, default=20, help='The number of batch_size')
    # NOTE(review): --epochs is parsed but not used by this preview script.
    parser.add_argument('--epochs', type=int, default=20, help='The number of epochs')
    # The defaults below reproduce the previously hard-coded values.
    parser.add_argument('--root', default=r'D:\Projects\DeepLearning\Dataset\flower_photos\train',
                        help='Dataset root with one sub-directory per class')
    parser.add_argument('--interval', type=float, default=10,
                        help='Seconds to pause after displaying each batch')
    args = parser.parse_args()

    viz = visdom.Visdom()  # client for the Visdom server
    # All images must share one size: the DataLoader stacks them into a single
    # batch tensor, and viz.images tiles that batch.
    tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    sample = MyData(args.root, tf)
    train_loader = torch.utils.data.DataLoader(sample, batch_size=args.batchsize, shuffle=True)
    for x, y in train_loader:
        viz.images(x, nrow=5, win='batch', opts=dict(title='batch'))
        viz.text(str(y), win='label', opts=dict(title='haha'))
        time.sleep(args.interval)  # leave time to inspect each batch


if __name__ == '__main__':
    main()
运行上述脚本之前,需要先在终端中启动Visdom服务:
python -m visdom.server