** 很多东西如果学过去,可能真的学过去了,当你需要用的时候,可能你不记得你学过;但是你面对不懂的知识,突然发现你学过,这样可能效果更好。不是为了学习而学习**
Pytorch提供了读取数据和对数据进行预处理的方法和类(可以简单地理解常用的方法和类):
1.torch.utils.data.Dataset提供的类:Dataset,它是一个抽象类,那么继承和重写这个类就可以定义自己的数据类,只需要定义__len__(),和__getitem__()这两个函数就可以了,下面是一个伪代码,简单的说一下怎样定义自己的数据类:
from torch.utils.data import Dataset
import pandas as pd # pandas库提供了读取csv文件的函数read_csv()
class myDataset(Dataset): # 定义自己的数据类myDataset,继承的抽象类Dataset
def __init__(self, csv_file, txt_file,root_dir,other_file): # csv_file:抽象的表示.csv文件;txt_file:抽象的表示txt文件;
# root_dir:地址,这些参数放在初始化函数里
self.csv_data= pd.read_csv(csv_file) # 读取csv文件,并且赋给他本身
with open(txt_file,'r') as f: # 读取txt文件,并且赋给他本身,读取的方式为:with open(...) as f:
data_list = f.readlines() # 读取每一行数据,并且放到data_list里
self.txt_data = data_list
self.root_dir = root_dir
# 实现下面这个方法:
def __len__(self): # 定义自己的数据类,必须重写这个方法(函数)
return len(self.csv_data) # 返回的数据的长度
def __getitem__(self, idx): # 定义自己的数据类,必须重写这个方法(函数)
data = (self.csv_data[idx],self.txt_data[idx]) # 获取数据的方式,按照索引进行的
return data
2.torch.utils.data已经提供的类:Dataset,但是通过这种方式只能一个个的数据的把数据全部读出来,定义了数据读取的方式,不能实现** 批量**的把数据读取出来,为此pytorch有提供了一个方法:DataLoader(),它的参数如下:
from torch.utils.data import DataLoader
dataiter = DataLoader(myDataset,batch_size=32,shuffle=True,collate_fn=default_collate)