一、为什么要使用Datasets类
Datasets是pytorch的一个类,pytorch自带多种数据集,如:MINIST等数据集就是在pytorch的Datasets的库中的。
Pytorch中有工具函数torch.utils.Data.DataLoader,通过这个函数我们在准备加载数据集使用mini-batch的时候可以使用多线程并行处理,这样可以加快我们准备数据集的速度。Datasets就是构建这个工具函数的实例参数之一。
二、如何定义Datasets?
Dataset类是Pytorch中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示:
def getitem(self, index):
def len(self):
其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数
这里重点看 getitem函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。
三、实战
数据集的内容组成
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from torchvision import transforms
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.layer1=nn.Sequential(
nn.Linear(3,20),
nn.Sigmoid(),
nn.Linear(20,40),
nn.Sigmoid(),
nn.Linear(40,1)
)
def forward(self,x):
data=x
data=self.layer1(data)
return data
class MyDataset(Dataset):
def __init__(self,root,transform=None):
super(MyDataset,self).__init__()
#读取数据,整理读取的x值为一列
df=pd.read_csv(root,dtype=np.float32)
#self.data=pd.DataFrame(columns=['data','label'])
data=[] #用于获取3个x值并组合为一列
label=[] #用于获取标签值
self.data=[]
self.label=[]
for i in range(df.shape[0]):
x=df.loc[i] #type:Series
data.append([x['x1'],x['x2'],x['x3']])
label.append(x['y'])
#self.data['data']=data
#self.data['label']=label
self.data=data
self.label=label
self.transform=transform
def __len__(self):
return len(self.data)
def __getitem__(self, item):
x=self.data[item]
label=self.label[item]
if self.transform is not None:
x=self.transform(x)
return x,label
class ToTensor(object):
def __call__(self, seq):
#print(seq.shape)
return torch.tensor(seq,dtype=torch.float)
if __name__=='__main__':
path = 'C:/Users/Mr.Li\Desktop/test project/train.csv'
set=MyDataset(path,ToTensor())
data=torch.utils.data.DataLoader(dataset=set,batch_size=6,shuffle=True)
model=Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for epoch in range(100):
for i,( x,label) in enumerate(data):
y=model(x)
z=label.view(-1,1)
loss = loss_func(y, z)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss)