一、Manual data feed
二、DataLoader
import torch
import numpy as np
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
class DiabetesDataset(Dataset):
    """Map-style dataset over a numeric CSV file (by default the diabetes data).

    The whole file is loaded eagerly as float32. Every column except the
    last becomes the feature tensor ``x_data``; the last column becomes the
    label tensor ``y_data``, kept two-dimensional so each label item has
    shape ``(1,)``.
    """

    def __init__(self, path="./root/diabetes.csv"):
        """Load the CSV at *path* into memory.

        Args:
            path: comma-delimited, UTF-8 numeric file with no header row
                (``np.loadtxt`` cannot parse a non-numeric header). Defaults
                to the location the original script hard-coded.
        """
        # ndmin=2 keeps a single-row file two-dimensional; without it
        # loadtxt returns a 1-D array, len() would report the column count
        # and the column slicing below would raise IndexError.
        xy = np.loadtxt(path, delimiter=",", dtype=np.float32,
                        encoding="utf-8", ndmin=2)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        # Index with a list (not a scalar) so labels stay (N, 1), matching
        # the (N, 1) shape of a one-output model's predictions.
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair for row *index*."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of rows loaded from the CSV."""
        return self.len
# Load the dataset once, then wrap it in a loader that serves shuffled
# mini-batches of 32 rows (the order is reshuffled on every pass).
dataset = DiabetesDataset()
train_loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,  # load batches in the main process, no worker subprocesses
)
class Model(torch.nn.Module):
    """Fully connected binary classifier: 8 -> 6 -> 4 -> 1 features.

    Every linear layer, including the last, is followed by a sigmoid, so
    the output lies in (0, 1) and can be fed straight into ``BCELoss``.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the order l1, l2, l3 and keep these
        # attribute names, so parameter initialisation and state_dict keys
        # stay the same.
        self.l1 = torch.nn.Linear(8, 6)
        self.l2 = torch.nn.Linear(6, 4)
        self.l3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map *x* (last dimension 8) to probabilities (last dimension 1)."""
        hidden = x
        for layer in (self.l1, self.l2, self.l3):
            hidden = self.sigmoid(layer(hidden))
        return hidden
model = Model()
# Binary cross-entropy on the sigmoid outputs, averaged over all elements.
# reduction="mean" is the supported equivalent of the deprecated
# size_average=True, which triggered a UserWarning on every run.
criterion = torch.nn.BCELoss(reduction="mean")
# Plain stochastic gradient descent over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Training loop: 2 epochs of mini-batch SGD, printing the loss per batch.
for epoch in range(2):
    for i, (inputs, labels) in enumerate(train_loader):
        # Tensors participate in autograd directly; the legacy Variable
        # wrapper is deprecated and a no-op, so it is not used here.
        # Forward pass.
        y_pred = model(inputs)
        # Compute and print the loss.
        loss = criterion(y_pred, labels)
        print(epoch, i, loss.item())
        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
运行结果:
0 0 0.666467010974884
D:\Anaconda3\lib\site-packages\torch\nn\_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
warnings.warn(warning.format(ret))
0 1 0.6770526170730591
0 2 0.6472709774971008
0 3 0.680681586265564
0 4 0.6463724374771118
0 5 0.7323727607727051
0 6 0.6793930530548096
0 7 0.6668210625648499
0 8 0.6497477889060974
0 9 0.6571529507637024
0 10 0.6358109712600708
0 11 0.6288620829582214
0 12 0.6288570761680603
0 13 0.6520243883132935
0 14 0.6097450256347656
0 15 0.587654173374176
0 16 0.6603480577468872
0 17 0.7319764494895935
0 18 0.601809561252594
0 19 0.5663074254989624
0 20 0.6739665269851685
0 21 0.709747850894928
0 22 0.6790766716003418
0 23 0.6723225116729736
1 0 0.5985385179519653
1 1 0.60416579246521
1 2 0.6502715349197388
1 3 0.6705628633499146
1 4 0.6059818267822266
1 5 0.6952323317527771
1 6 0.6079085469245911
1 7 0.6134516000747681
1 8 0.6743714213371277
1 9 0.6363943815231323
1 10 0.6723730564117432
1 11 0.6787806153297424
1 12 0.6097451448440552
1 13 0.5894257426261902
1 14 0.597471296787262
1 15 0.8322824239730835
1 16 0.6011913418769836
1 17 0.6702829599380493
1 18 0.6096935868263245
1 19 0.6891161799430847
1 20 0.7401865720748901
1 21 0.6295170783996582
1 22 0.5769231915473938
1 23 0.6107428073883057