参考错错莫课代表的PyTorch *深度学习实践 第8讲*https://blog.csdn.net/bit452/article/details/109686474
笔记:
需要mini_batch 就需要import Dataset和DataLoader;
Dataset是一个抽象类,不能直接实例化,所以我们要创建一个自己的类,继承Dataset;
继承Dataset后我们必须实现三个魔法函数:
__init__()是初始化函数,之后我们可以提供数据集路径进行数据的加载
__getitem__()帮助我们通过索引找到某个样本
__len__()帮助我们返回数据集大小
代码如下:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


# prepare dataset
class DiabetesDataset(Dataset):
    """Dataset for a comma-separated numeric file whose last column is the label.

    Implements the three required methods: __init__ (load the data),
    __getitem__ (index one sample), __len__ (dataset size).
    """

    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        # shape is a (rows, cols) pair; [0] is the number of samples
        self.len = xy.shape[0]
        # all rows, every column except the last one -> feature matrix
        self.x_data = torch.from_numpy(xy[:, :-1])
        # indexing with [-1] (a list) keeps a 2-D (n, 1) matrix
        # instead of collapsing to a 1-D vector
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


# design model using class
class Model(torch.nn.Module):
    """8 -> 6 -> 4 -> 1 MLP: ReLU hidden layers, sigmoid output for BCE loss."""

    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        # treat the activation as a network layer (a module), not just a
        # function -- this makes it easy to swap activations when debugging
        self.activate = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


# training cycle: forward, backward, update
if __name__ == '__main__':
    # This guard is mandatory on Windows: with num_workers > 0 the DataLoader
    # worker processes re-import this module, and any unguarded top-level
    # code would run again in every worker. For that reason ALL executable
    # statements (not just the loop) live under the guard.
    dataset = DiabetesDataset('diabetes.csv')
    train_loader = DataLoader(dataset=dataset,
                              batch_size=32,
                              shuffle=True,
                              num_workers=2)

    model = Model()

    # construct loss and optimizer
    # BUG FIX: the original passed size_average='mean'. `size_average` is a
    # deprecated boolean flag, and any non-empty string is simply truthy;
    # the modern, correct argument is reduction='mean' (average the loss
    # over elements; 'sum' accumulates instead -- try both).
    criterion = torch.nn.BCELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(100):
        for i, data in enumerate(train_loader, 0):
            # unpack one mini-batch: inputs and the matching labels
            inputs, labels = data
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
部分输出结果:
0 0.707435131072998
0 0.7413188815116882
0 0.7382362484931946
0 0.7657931447029114
0 0.7216707468032837
0 0.731844425201416
0 0.7010486721992493
0 0.7025167346000671
0 0.7090010046958923
0 0.6981418132781982
0 0.6918289661407471
0 0.6823720932006836
0 0.7050893306732178
0 0.6856189966201782
0 0.6769058108329773
0 0.6759095788002014
0 0.6752151846885681
0 0.7018520832061768
0 0.6617619395256042
0 0.6989297270774841
0 0.6508258581161499
0 0.6442074179649353
0 0.6777867078781128
0 0.6061913371086121...
99 0.3661899268627167
99 0.34995681047439575
99 0.6093645095825195
99 0.39561158418655396
99 0.3813444972038269
99 0.3222960829734802
99 0.30144819617271423
99 0.34436652064323425
99 0.5827363133430481
99 0.546252965927124
99 0.3984104096889496
99 0.38216859102249146
99 0.4403568506240845
99 0.34911349415779114
99 0.4326655864715576
99 0.48392561078071594
99 0.4723842442035675
99 0.5359323024749756
99 0.5373116135597229
99 0.4402441084384918
99 0.47346335649490356
99 0.5459398627281189
99 0.30009403824806213