需要定义 DiabetesDataset 作为加载 diabetes 数据集的类，它继承自 Dataset。Dataset 是抽象类，需要实现其中的三个方法：__init__、__getitem__ 和 __len__。
import torch
from torch.utils.data import Dataset # 抽象类
from torch.utils.data import DataLoader
import numpy as np
class DiabetesDataset(Dataset):
    """Map-style dataset backed by a comma-separated numeric text file.

    Every column except the last is treated as a feature; the last column is
    the target. ``np.loadtxt`` decompresses ``.gz``/``.bz2`` files
    transparently, so a path such as ``'diabetes.csv.gz'`` can be passed as-is.

    Args:
        filepath: Path to the data file.

    Attributes:
        x_data: float32 tensor of shape (n_samples, n_columns - 1).
        y_data: float32 tensor of shape (n_samples, 1).
        len: Number of samples (rows) in the file.
    """

    def __init__(self, filepath):
        # ndmin=2 keeps a one-row file two-dimensional; without it loadtxt
        # returns a 1-D array, shape[0] would count columns instead of rows,
        # and the 2-D slicing below would raise IndexError.
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32, ndmin=2)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Indexing with the list [-1] (not the scalar -1) keeps the column
        # axis, so targets have shape (n, 1) rather than (n,).
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples."""
        return self.len
dataset = DiabetesDataset('diabetes.csv.gz')
# Mini-batch iterator over the whole dataset.
train_loader = DataLoader(
    dataset,
    batch_size=32,  # samples per mini-batch
    shuffle=True,   # reshuffle the sample order at the start of every epoch
    num_workers=0,  # 0 = load batches in the main process; a value > 0 starts
                    # that many worker *processes* (not threads). On Windows the
                    # workers are spawned, so code that iterates the loader must
                    # sit under `if __name__ == '__main__':`.
)
class Model(torch.nn.Module):
    """Three-layer fully connected network ending in a sigmoid.

    Layer widths are ``in_features -> 6 -> 4 -> 1``, with a sigmoid after every
    linear layer, so each output lies in (0, 1).

    Args:
        in_features: Size of the last input dimension (default 8).
    """

    def __init__(self, in_features=8):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        # Parameter-free module equivalent to torch.sigmoid; both forms are
        # tracked by autograd. Using the module just makes it part of the
        # model's submodules (e.g. visible in print(model)).
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map input of shape (..., in_features) to output of shape (..., 1)."""
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x
model = Model()
# Binary cross-entropy averaged over the mini-batch; it expects probabilities,
# which is what the sigmoid at the end of Model produces.
criterion = torch.nn.BCELoss(reduction='mean')
# Plain stochastic gradient descent over every parameter of the model.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)
if __name__ == '__main__':
    # Train for 100 epochs; each epoch is one pass over train_loader.
    for epoch in range(100):
        for i, (inputs, labels) in enumerate(train_loader):
            # 1. Forward pass: predictions and loss for this mini-batch.
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 2. Backward pass. Gradients accumulate across .backward() calls,
            #    so the previous step's gradients are cleared first.
            optimizer.zero_grad()
            loss.backward()
            # 3. Update the parameters from the fresh gradients.
            optimizer.step()
输出:
0 0 0.6936783194541931
0 1 0.693471372127533
0 2 0.6917673349380493
0 3 0.6861389875411987
0 4 0.6913132667541504
0 5 0.6789288520812988
0 6 0.6768878698348999
0 7 0.6651645302772522
0 8 0.6861144304275513
0 9 0.6686166524887085
0 10 0.6661809682846069
0 11 0.6636384129524231
0 12 0.6618748307228088
0 13 0.6681938767433167
0 14 0.6153277158737183
0 15 0.6548603773117065
... ...
98 14 0.5910221338272095
98 15 0.6699521541595459
98 16 0.6283824443817139
98 17 0.6495291590690613
98 18 0.6865949630737305
98 19 0.6016601920127869
98 20 0.630635678768158
98 21 0.6044492721557617
98 22 0.6302173137664795
98 23 0.6102578043937683
99 0 0.5284566283226013
99 1 0.6872431039810181
99 2 0.6330350041389465
99 3 0.6103817820549011
99 4 0.6251040697097778
99 5 0.6059320569038391
99 6 0.6281994581222534
99 7 0.6733802556991577
99 8 0.6273549795150757
99 9 0.7067252993583679
99 10 0.6479067802429199
99 11 0.7034580111503601
99 12 0.633543848991394
99 13 0.5920330882072449
99 14 0.6311102509498596
99 15 0.6479007601737976
99 16 0.6280706524848938
99 17 0.6995146870613098
99 18 0.6469420790672302
99 19 0.6414950489997864
99 20 0.5969923734664917
99 21 0.5866757035255432
99 22 0.5923041105270386
99 23 0.524055004119873