PyTorch Learning (3): DataLoader

1. Manual data feed
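
Before introducing DataLoader, the simplest way to train is to load the whole dataset into memory and feed every sample to the model at once (full-batch gradient descent). Below is a minimal sketch of that approach, assuming the same diabetes.csv file used in section 2; the Sequential model here is just an inline equivalent of the Model class defined there.

import torch
import numpy as np

# Load the entire CSV once; columns 0..7 are features, the last column is the 0/1 label.
xy = np.loadtxt("./root/diabetes.csv", delimiter=",", dtype=np.float32, encoding="utf-8")
x_data = torch.from_numpy(xy[:, 0:-1])   # shape (N, 8)
y_data = torch.from_numpy(xy[:, [-1]])   # shape (N, 1)

# Same architecture as the Model class in section 2, written as a Sequential for brevity.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 6), torch.nn.Sigmoid(),
    torch.nn.Linear(6, 4), torch.nn.Sigmoid(),
    torch.nn.Linear(4, 1), torch.nn.Sigmoid(),
)
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(100):
    y_pred = model(x_data)            # forward pass over all N samples at once
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    optimizer.zero_grad()             # backward pass and weight update
    loss.backward()
    optimizer.step()

This is fine for a small table, but once the data no longer fits in memory, or shuffled mini-batches are wanted, the Dataset/DataLoader combination in the next section is the idiomatic replacement.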

2. DataLoader
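
The complete example below defines a custom Dataset that loads the diabetes CSV once in __init__, wraps it in a DataLoader that yields shuffled mini-batches of 32 samples, and trains a small three-layer network with binary cross-entropy loss for 2 epochs.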

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class DiabetesDataset(Dataset):
    # Initialize your data
    def __init__(self):
        xy = np.loadtxt("./root/diabetes.csv", delimiter=",", dtype=np.float32, encoding="utf-8")
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    # Return one sample (features, label) at the given index
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    # return the data length
    def __len__(self):
        return self.len

dataset = DiabetesDataset()
train_loader = DataLoader(dataset=dataset, batch_size=32,
                          shuffle=True, num_workers=0)

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.l1 = torch.nn.Linear(8, 6)
        self.l2 = torch.nn.Linear(6, 4)
        self.l3 = torch.nn.Linear(4, 1)

        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        out1 = self.sigmoid(self.l1(x))
        out2 = self.sigmoid(self.l2(out1))
        y_pred = self.sigmoid(self.l3(out2))

        return y_pred

model = Model()

criterion = torch.nn.BCELoss(reduction='mean')  # size_average is deprecated; reduction='mean' is the current equivalent
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# training loop
for epoch in range(2):
    for i, data in enumerate(train_loader, 0):
        # get inputs; the DataLoader already yields tensors, so no Variable wrapping is needed
        inputs, labels = data

        # forward
        y_pred = model(inputs)

        # compute and print loss
        loss = criterion(y_pred, labels)
        print(epoch, i, loss.item())

        # zero gradient, perform a backward pass, and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Output:

0 0 0.666467010974884
0 1 0.6770526170730591
0 2 0.6472709774971008
0 3 0.680681586265564
0 4 0.6463724374771118
0 5 0.7323727607727051
0 6 0.6793930530548096
0 7 0.6668210625648499
0 8 0.6497477889060974
0 9 0.6571529507637024
0 10 0.6358109712600708
0 11 0.6288620829582214
0 12 0.6288570761680603
0 13 0.6520243883132935
0 14 0.6097450256347656
0 15 0.587654173374176
0 16 0.6603480577468872
0 17 0.7319764494895935
0 18 0.601809561252594
0 19 0.5663074254989624
0 20 0.6739665269851685
0 21 0.709747850894928
0 22 0.6790766716003418
0 23 0.6723225116729736
1 0 0.5985385179519653
1 1 0.60416579246521
1 2 0.6502715349197388
1 3 0.6705628633499146
1 4 0.6059818267822266
1 5 0.6952323317527771
1 6 0.6079085469245911
1 7 0.6134516000747681
1 8 0.6743714213371277
1 9 0.6363943815231323
1 10 0.6723730564117432
1 11 0.6787806153297424
1 12 0.6097451448440552
1 13 0.5894257426261902
1 14 0.597471296787262
1 15 0.8322824239730835
1 16 0.6011913418769836
1 17 0.6702829599380493
1 18 0.6096935868263245
1 19 0.6891161799430847
1 20 0.7401865720748901
1 21 0.6295170783996582
1 22 0.5769231915473938
1 23 0.6107428073883057
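
The batch indices 0 to 23 above show that each epoch produced 24 mini-batches from the CSV. To sanity-check what the loader yields, it can help to pull a single batch and inspect its shapes; a small sketch (note that the last batch of an epoch may be smaller than 32):

# Pull one mini-batch from the loader and look at its shapes.
inputs, labels = next(iter(train_loader))
print(inputs.shape)   # torch.Size([32, 8]) - 32 samples, 8 features each
print(labels.shape)   # torch.Size([32, 1]) - one 0/1 target per sample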

 
