Epoch:所有样本都进行一次forward和backward。
Batch_size:训练样本中一次forward和backward的样本数。
Iteration:内迭代的次数,即训练集中有多少个batch_size,总样本数N/Batch_size。
Dataloader中的batch和shuffle:
示例代码:(还是diabetes.csv.gz数据集)
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
class DiabetesDataset(Dataset):
def __init__(self,filepath):
xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
self.len = xy.shape[0]
self.x_data = torch.from_numpy(xy[:, :-1])
self.y_data = torch.from_numpy(xy[:, [-1]])
def __getitem__(self, index):
#可以按索引取数据
return self.x_data[index],self.y_data[index] #返回的是(x_data,y_data)元组
def __len__(self):
#可以获得数据集长度
return self.len
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(8,6)
self.linear2 = torch.nn.Linear(6,4)
self.linear3 = torch.nn.Linear(4,1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self,x):
x = self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
model = Model()
filepath = 'diabetes.csv.gz'
dataset = DiabetesDataset(filepath)
train_loader = DataLoader(dataset=dataset,batch_size=22,shuffle=True,num_workers=2)
# print(train_loader)
critarion = torch.nn.BCELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(),lr=0.01) #将模型参数作为优化算法的参数
if __name__ == '__main__':
epoch_list = []
loss_list = []
for epoch in range(20):
epoch_list.append(epoch)
for i, data in enumerate(train_loader,0):
#准备数据,enumerate()用来列出数据和数据下标
input, labels = data #是张量
#计算损失
y_pred = model(input)
loss = critarion(y_pred, labels)
print(epoch,i,loss.item()) #epoch是迭代次数,i即iteration是内迭代次数,就是总样本数/batch_size
loss_list.append(loss.item())
#反馈
optimizer.zero_grad()
loss.backward()
#更新
optimizer.step()
# plt.plot(epoch_list,loss_list)
# plt.xlabel('epoch')
# plt.ylabel('loss')
# plt.show()
说明:
1.通过Dataset类构造带有索引的数据集,它是一个抽象类,只能被其他子类继承,不能被实例化,要调用的话就要继承Dataset来产生一个可实例化的自定义的类,即代码中的class DiabetesDataset(Dataset),重写__getitem__()函数实现按索引取数据,返回的是(x_data,y_data)元组,重写__len__()函数获得数据集长度。
2.DataLoader是将数据集分成Mini-Batch的类。帮助加载数据集
3.num_workers:用几个线程来并行计算
4.enumerate()函数:列出数据及数据下标