DataLoader:batch_size=2, shuffle=True, num_workers=2
shuffle=True 就是先把数据打乱,
然后再按照 batch_size 的大小分组;num_workers=2 表示使用 2 个进程来加载数据
在 Windows 系统上,Python 使用多进程(num_workers > 0)加载数据时会报错,要把训练代码放入
if __name__ == '__main__':
里面,代码就可以正常运行了
代码如下:
import torch
import numpy as np
#两个帮助加载数据的工具类
from torch.utils.data import Dataset #构造数据集,支持索引
from torch.utils.data import DataLoader # 每次取出一个 mini-batch 的数据以供训练
import matplotlib.pyplot as plt
# Dataset是一个抽象类,不能被实例化,只能被子类去继承
# 自己定义一个类,继承自Dataset
class DiabetesDataset(Dataset):
    """Map-style dataset backed by a comma-delimited numeric text file.

    The whole file is read into memory when the dataset is constructed,
    which is appropriate for small files. Every column except the last is
    treated as a feature; the last column is the target.
    """

    def __init__(self, filepath):
        """Load ``filepath`` and split it into feature and target tensors.

        Args:
            filepath: Path of a comma-delimited text file that
                ``numpy.loadtxt`` can parse as float32 values.
        """
        # ndmin=2 keeps a one-row file as shape (1, n_columns); without it
        # loadtxt returns a 1-D array and the 2-D slicing below would fail.
        xy = np.loadtxt(filepath, delimiter=",", dtype=np.float32, ndmin=2)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Indexing with the list [-1] keeps the target two-dimensional
        # (n_rows, 1), matching the (batch, 1) output of the model.
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of rows (samples) that were loaded."""
        return self.len
# Read the whole diabetes CSV into memory once, up front.
dataset = DiabetesDataset('data/diabetes.csv')

# DataLoader is a concrete class: given a dataset that supports indexing and
# len(), it assembles the mini-batches for us.
#   batch_size  - number of samples per mini-batch
#   shuffle     - reshuffle the samples before batching
#   num_workers - number of worker processes that load data in parallel;
#                 4 or 8 is a common choice, and more is not always faster
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
class Model(torch.nn.Module):
    """Three-layer fully connected binary classifier (8 -> 6 -> 4 -> 1).

    A sigmoid follows every linear layer, so the output lies in (0, 1) and
    can be passed directly to ``torch.nn.BCELoss``.
    """

    def __init__(self):
        """Create the three linear layers and the shared sigmoid."""
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        # Sigmoid holds no parameters, so one instance serves all layers.
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map inputs whose last dimension is 8 to outputs of last dimension 1.

        Args:
            x: Float tensor whose last dimension has size 8.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1, with every value in (0, 1).
        """
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x
model = Model()
# Binary cross-entropy averaged over the samples of each mini-batch.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# The training loop must sit behind the __main__ guard: with num_workers > 0
# the DataLoader starts worker processes, and on Windows those re-import this
# module. Without the guard each worker would try to start training again and
# the script fails. Note the double underscores in __name__ / '__main__'.
if __name__ == '__main__':
    # Single source of truth for the epoch count; the plot below derives its
    # x-axis from the recorded losses, so changing this value is safe.
    num_epochs = 100
    loss_list = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        seen_samples = 0
        for i, (inputs, labels) in enumerate(train_loader, 0):
            # Forward pass
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            batch_loss = loss.item()
            print(epoch, i, batch_loss)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size because the last batch may be smaller.
            batch_size = inputs.size(0)
            running_loss += batch_loss * batch_size
            seen_samples += batch_size
        # Record the mean loss over the whole epoch rather than the loss of
        # whichever mini-batch happened to come last, so the plotted "Cost"
        # reflects the full data set.
        loss_list.append(running_loss / seen_samples)

    plt.plot(range(len(loss_list)), loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('Cost')
    plt.show()
结果如下:
D:\Study\Anacodna\envs\DL\python.exe D:\Study\Code\LearnPyorch\08.py
0 0 0.7400119304656982
0 1 0.7308682203292847
0 2 0.7304829359054565
0 3 0.7034912705421448
0 4 0.75556480884552
0 5 0.780892014503479
0 6 0.7275230884552002
0 7 0.7690879106521606
0 8 0.718267023563385
0 9 0.718049943447113
0 10 0.7577503323554993
0 11 0.7638158798217773
0 12 0.7394761443138123
0 13 0.7458901405334473
0 14 0.7449460625648499
0 15 0.7145881652832031
0 16 0.7648269534111023
0 17 0.7135511636734009
0 18 0.7338627576828003
0 19 0.7539408802986145
0 20 0.7389920353889465
0 21 0.7247976064682007
0 22 0.7374752759933472
0 23 0.7196083068847656
......
98 23 0.5632990002632141
99 0 0.6828345656394958
99 1 0.6029864549636841
99 2 0.5435824990272522
99 3 0.6828681230545044
99 4 0.6831385493278503
99 5 0.5638036131858826
99 6 0.6036509871482849
99 7 0.6636937856674194
99 8 0.5834413766860962
99 9 0.663272500038147
99 10 0.6432188749313354
99 11 0.7235974073410034
99 12 0.6032260656356812
99 13 0.62349534034729
99 14 0.6428340673446655
99 15 0.6233822703361511
99 16 0.6630892157554626
99 17 0.683122456073761
99 18 0.6030187606811523
99 19 0.582496702671051
99 20 0.7442654371261597
99 21 0.6436271667480469
99 22 0.7229583263397217
99 23 0.7287114858627319
Process finished with exit code 0
图像如下
作业见主页标题为泰坦尼克号幸存者预测的那一篇文章