【PyTorch Deep Learning 8】Loading Datasets

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # Load the whole CSV into memory: 8 feature columns followed by
        # 1 label column, all read as float32.
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])   # every column except the last
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column, kept 2-D as [N, 1]

    def __getitem__(self, index):
        # Return a single (features, label) pair; DataLoader collates
        # these pairs into batches.
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

dataset = DiabetesDataset('diabetes.csv.gz')
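# Implementing __getitem__ and __len__ is the whole Dataset contract:
# dataset[0] now returns one (x, y) pair and len(dataset) the row count,
# which is exactly the interface DataLoader builds batches from. For
# example (shapes assume 8 feature columns plus 1 label column):
#
# x0, y0 = dataset[0]
# print(x0.shape, y0.shape)  # torch.Size([8]) torch.Size([1])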

# shuffle=True reshuffles the samples at every epoch; num_workers=2 loads
# batches in two parallel worker processes.
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)
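# Optional sanity check (not part of the original script): pull one batch
# and confirm its shape. A full batch should be [32, 8] inputs and [32, 1]
# labels; the last batch of an epoch may be smaller. Uncomment to run, and
# keep it under the __main__ guard below because num_workers=2 spawns
# worker processes:
#
# inputs, labels = next(iter(train_loader))
# print(inputs.shape, labels.shape)  # torch.Size([32, 8]) torch.Size([32, 1])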

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Three linear layers shrink the 8 input features to a single output.
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        # self.activate = torch.nn.ReLU()  # an alternative hidden activation to try
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        # The final sigmoid keeps the output in (0, 1), as BCELoss requires.
        x = self.sigmoid(self.linear3(x))
        return x

model = Model()

criterion = torch.nn.BCELoss(reduction='mean')  # mean binary cross-entropy over the batch
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# The __main__ guard is required when num_workers > 0 on platforms that
# start workers by spawning (e.g. Windows); without it the workers would
# re-execute the module on import.
if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0):  # i is the batch index within the epoch
            inputs, labels = data
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)

            print(epoch, i, loss.item())
            optimizer.zero_grad()  # clear gradients left over from the previous step
            loss.backward()
            optimizer.step()

Partial output (each line is epoch, batch index, loss):

0 0 0.807451605796814
0 1 0.7728536128997803
0 2 0.8022035956382751
0 3 0.7452567219734192
0 4 0.737909734249115
0 5 0.730575442314148
0 6 0.7187033891677856
0 7 0.7092961668968201
0 8 0.712911069393158
0 9 0.6880234479904175
0 10 0.7077547907829285
0 11 0.701163649559021
0 12 0.697856068611145
0 13 0.6926714777946472
0 14 0.6870428919792175
0 15 0.6854082345962524
0 16 0.6911633014678955
0 17 0.6793688535690308
0 18 0.6736245155334473
0 19 0.6988233923912048
0 20 0.6675996780395508
0 21 0.6701093912124634
0 22 0.6625561714172363
0 23 0.6497127413749695
1 0 0.6436294317245483
1 1 0.6548228859901428
1 2 0.660560131072998
1 3 0.6324721574783325
1 4 0.6946964859962463
1 5 0.6256315112113953
1 6 0.6329830288887024
1 7 0.675372302532196
1 8 0.6632266640663147
1 9 0.6629412770271301
1 10 0.636751651763916
1 11 0.5821835994720459
1 12 0.6338260769844055
1 13 0.6758953332901001
1 14 0.7215102910995483
1 15 0.6919896006584167
1 16 0.5825151801109314
1 17 0.5963782668113708
1 18 0.6281926035881042
1 19 0.6950696110725403
1 20 0.6956918835639954
1 21 0.6612802743911743
1 22 0.697222888469696
1 23 0.6711336970329285
2 0 0.5891063809394836
2 1 0.5515356063842773
2 2 0.6244746446609497
2 3 0.7373515367507935
2 4 0.5667764544487
2 5 0.6427653431892395
2 6 0.6815708875656128
2 7 0.6031674146652222
2 8 0.6426284313201904
2 9 0.6022633910179138
2 10 0.6427586674690247
2 11 0.6427531242370605
2 12 0.7255752086639404
2 13 0.6427459716796875
2 14 0.6834381222724915
2 15 0.5187731385231018
2 16 0.6635192036628723
2 17 0.7055662870407104
2 18 0.6627863049507141
2 19 0.620648980140686
2 20 0.662699818611145
2 21 0.7671647071838379
2 22 0.683029294013977
2 23 0.6164393424987793
...
97 0 0.41560155153274536
97 1 0.49379831552505493
97 2 0.36065608263015747
97 3 0.5267201662063599
97 4 0.5858941674232483
97 5 0.6695444583892822
97 6 0.27213338017463684
97 7 0.5485180020332336
97 8 0.5516160130500793
97 9 0.37532785534858704
97 10 0.4675712585449219
97 11 0.23860856890678406
97 12 0.5142492055892944
97 13 0.4175170660018921
97 14 0.4006308615207672
97 15 0.5588406920433044
97 16 0.4027820825576782
97 17 0.5376626253128052
97 18 0.4327486753463745
97 19 0.46236151456832886
97 20 0.3504275977611542
97 21 0.5225668549537659
97 22 0.4279901087284088
97 23 0.5687180161476135
98 0 0.4553259611129761
98 1 0.45436468720436096
98 2 0.4701966643333435
98 3 0.48931851983070374
98 4 0.44848868250846863
98 5 0.3594033122062683
98 6 0.41738736629486084
98 7 0.5094685554504395
98 8 0.4144607186317444
98 9 0.571685791015625
98 10 0.49335598945617676
98 11 0.3502487242221832
98 12 0.4083734154701233
98 13 0.6185716390609741
98 14 0.3824211657047272
98 15 0.35667693614959717
98 16 0.7378003001213074
98 17 0.2549207806587219
98 18 0.607338547706604
98 19 0.521765947341919
98 20 0.32979416847229004
98 21 0.4575522243976593
98 22 0.4112623333930969
98 23 0.558805525302887
99 0 0.5199065804481506
99 1 0.49636274576187134
99 2 0.42843925952911377
99 3 0.4882776737213135
99 4 0.33503004908561707
99 5 0.41485053300857544
99 6 0.31385156512260437
99 7 0.3722498416900635
99 8 0.3458700180053711
99 9 0.5599054098129272
99 10 0.5374796986579895
99 11 0.40813690423965454
99 12 0.6304723024368286
99 13 0.5381195545196533
99 14 0.46311262249946594
99 15 0.4137115180492401
99 16 0.41994985938072205
99 17 0.3973977565765381
99 18 0.5244745016098022
99 19 0.5306398868560791
99 20 0.49507936835289
99 21 0.595230758190155
99 22 0.40060484409332275
99 23 0.417412668466568
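
The loss bounces around from line to line because each value is computed on a single shuffled mini-batch of 32 samples rather than averaged over an epoch; the overall trend still falls from around 0.7 at the start to mostly 0.3-0.6 by epoch 99. To get a quick read on how well the model fits the training data once the loop finishes, here is a minimal evaluation sketch; it is not part of the original script, assumes the conventional 0.5 decision threshold, and would sit inside the __main__ guard after the training loop:

with torch.no_grad():                        # no gradients needed for evaluation
    y_prob = model(dataset.x_data)           # predicted probabilities in (0, 1)
    y_hat = (y_prob >= 0.5).float()          # threshold into hard 0/1 predictions
    acc = (y_hat == dataset.y_data).float().mean().item()
    print('training accuracy:', acc)

Note that this measures accuracy on the same data the model was trained on; a held-out split would be needed for a fair estimate.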