PyTorch极简入门教程(八)—— 准确率(accuracy)计算

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from sklearn.model_selection import train_test_split

# Load the HR dataset and one-hot encode the two categorical columns
# ('salary' and 'part'), dropping the original string columns.
data = pd.read_csv('dataset/HR.csv')
data.head()  # notebook-style inspection; has no effect when run as a script
data.info()
data.salary.unique()
data = data.join(pd.get_dummies(data.salary))
del data['salary']
data = data.join(pd.get_dummies(data.part))
del data['part']
data.head()

# Keep the label as a column vector (N, 1) so it matches the (N, 1) output
# of the model: BCELoss requires identical input/target shapes, and a 1-D
# label would broadcast against (N, 1) predictions in element-wise comparisons.
Y_data = data.left.values.reshape(-1, 1)
print("Y_data.shape:\t", Y_data.shape)
Y = torch.from_numpy(Y_data).type(torch.FloatTensor)

# pd.get_dummies produces bool columns in recent pandas; mixed with the numeric
# columns, `.values` would be an object array that torch.from_numpy rejects,
# so cast the feature matrix to float32 explicitly.
X_data = data[[c for c in data.columns if c != 'left']].values.astype(np.float32)
X = torch.from_numpy(X_data).type(torch.FloatTensor)

loss_fn = nn.BCELoss()
batch = 64
no_of_batches = len(data)//batch
epochs = 100

import torch.nn.functional as F


class Model(nn.Module):
    """Fully connected binary classifier.

    Maps ``in_features`` inputs through two 64-unit ReLU hidden layers to a
    single sigmoid output, i.e. one probability in (0, 1) per input row.
    """

    def __init__(self, in_features=20):
        """Create the three linear layers.

        Args:
            in_features: number of input columns. Defaults to 20, the width
                this script originally hard-coded for the encoded HR features.
        """
        super().__init__()
        # Attribute names are kept as-is so saved state_dict keys stay valid.
        self.liner_1 = nn.Linear(in_features, 64)
        self.liner_2 = nn.Linear(64, 64)
        self.liner_3 = nn.Linear(64, 1)

    def forward(self, input):
        """Return sigmoid probabilities of shape (batch, 1)."""
        x = F.relu(self.liner_1(input))
        x = F.relu(self.liner_2(x))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.liner_3(x))

lr = 0.0001


def get_model():
    """Build a fresh ``Model`` and an Adam optimizer over its parameters.

    The learning rate comes from the module-level ``lr``.

    Returns:
        A ``(model, optimizer)`` tuple.
    """
    net = Model()
    return net, torch.optim.Adam(net.parameters(), lr=lr)

# Split features/labels with sklearn's default test fraction, then convert
# every partition (in train_x, test_x, train_y, test_y order) to float32 tensors.
train_x, test_x, train_y, test_y = (
    torch.from_numpy(part).type(torch.float32)
    for part in train_test_split(X_data, Y_data)
)

# Training batches are reshuffled every epoch; the test loader keeps a fixed order.
train_ds = TensorDataset(train_x, train_y)
test_ds = TensorDataset(test_x, test_y)
train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch)

def accuracy(y_pred, y_true, threshold=0.5):
    """Return the fraction of samples whose thresholded prediction equals the label.

    Args:
        y_pred: predicted probabilities, e.g. the (N, 1) sigmoid output of the model.
        y_true: 0/1 labels with the same number of elements as ``y_pred``;
            both (N,) and (N, 1) layouts are accepted.
        threshold: probability above which a prediction counts as class 1.

    Returns:
        A 0-dim float32 tensor with the accuracy in [0, 1].

    Raises:
        ValueError: if ``y_pred`` and ``y_true`` hold different numbers of elements.
    """
    # Flatten both sides: comparing an (N, 1) tensor with an (N,) tensor would
    # broadcast to (N, N) and silently produce a meaningless score.
    predicted = (y_pred > threshold).type(torch.float32).reshape(-1)
    target = y_true.type(torch.float32).reshape(-1)
    if predicted.numel() != target.numel():
        raise ValueError(
            f"y_pred has {predicted.numel()} elements but y_true has {target.numel()}"
        )
    return (predicted == target).float().mean()

model, optim = get_model()

for epoch in range(epochs):
    # One optimisation pass over the shuffled training mini-batches.
    for x, y in train_dl:
        y_pred = model(x)
        # view_as aligns the label batch with the model's output shape;
        # BCELoss requires the input and target shapes to be identical.
        loss = loss_fn(y_pred, y.view_as(y_pred))
        optim.zero_grad()
        loss.backward()
        optim.step()

    # Evaluate on the full train and test sets without tracking gradients.
    # Each forward pass is computed once and reused for both metrics.
    with torch.no_grad():
        train_pred = model(train_x)
        train_target = train_y.view_as(train_pred)
        epoch_accuracy = accuracy(train_pred, train_target)
        epoch_loss = loss_fn(train_pred, train_target)

        test_pred = model(test_x)
        test_target = test_y.view_as(test_pred)
        epoch_test_accuracy = accuracy(test_pred, test_target)
        epoch_test_loss = loss_fn(test_pred, test_target)
        print('epoch: ', epoch, 'loss: ', round(epoch_loss.item(), 3),
              'accuracy:', round(epoch_accuracy.item(), 3),
              'test_loss: ', round(epoch_test_loss.item(), 3),
              'test_accuracy:', round(epoch_test_accuracy.item(), 3)
              )

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14999 entries, 0 to 14998
Data columns (total 10 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   satisfaction_level     14999 non-null  float64
 1   last_evaluation        14999 non-null  float64
 2   number_project         14999 non-null  int64  
 3   average_montly_hours   14999 non-null  int64  
 4   time_spend_company     14999 non-null  int64  
 5   Work_accident          14999 non-null  int64  
 6   left                   14999 non-null  int64  
 7   promotion_last_5years  14999 non-null  int64  
 8   part                   14999 non-null  object 
 9   salary                 14999 non-null  object 
dtypes: float64(2), int64(6), object(2)
memory usage: 1.1+ MB
Y_data.shape:	 (14999,)
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\functional.py:1350: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\modules\loss.py:498: UserWarning: Using a target size (torch.Size([64])) that is different to the input size (torch.Size([64, 1])) is deprecated. Please ensure they have the same size.
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\modules\loss.py:498: UserWarning: Using a target size (torch.Size([49])) that is different to the input size (torch.Size([49, 1])) is deprecated. Please ensure they have the same size.
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\modules\loss.py:498: UserWarning: Using a target size (torch.Size([11249])) that is different to the input size (torch.Size([11249, 1])) is deprecated. Please ensure they have the same size.
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\modules\loss.py:498: UserWarning: Using a target size (torch.Size([3750])) that is different to the input size (torch.Size([3750, 1])) is deprecated. Please ensure they have the same size.
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
epoch:  0 loss:  0.566 accuracy: 0.759 test_loss:  0.558 test_accuracy: 0.772
epoch:  1 loss:  0.566 accuracy: 0.759 test_loss:  0.558 test_accuracy: 0.772
epoch:  2 loss:  0.566 accuracy: 0.759 test_loss:  0.559 test_accuracy: 0.772
epoch:  3 loss:  0.563 accuracy: 0.759 test_loss:  0.554 test_accuracy: 0.772
epoch:  4 loss:  0.563 accuracy: 0.759 test_loss:  0.553 test_accuracy: 0.772
epoch:  5 loss:  0.561 accuracy: 0.759 test_loss:  0.553 test_accuracy: 0.772
epoch:  6 loss:  0.557 accuracy: 0.759 test_loss:  0.547 test_accuracy: 0.772
epoch:  7 loss:  0.554 accuracy: 0.759 test_loss:  0.543 test_accuracy: 0.772
epoch:  8 loss:  0.556 accuracy: 0.759 test_loss:  0.544 test_accuracy: 0.772
epoch:  9 loss:  0.546 accuracy: 0.759 test_loss:  0.537 test_accuracy: 0.772
epoch:  10 loss:  0.554 accuracy: 0.759 test_loss:  0.546 test_accuracy: 0.772
epoch:  11 loss:  0.536 accuracy: 0.759 test_loss:  0.525 test_accuracy: 0.772
epoch:  12 loss:  0.53 accuracy: 0.759 test_loss:  0.52 test_accuracy: 0.772
epoch:  13 loss:  0.525 accuracy: 0.759 test_loss:  0.513 test_accuracy: 0.772
epoch:  14 loss:  0.521 accuracy: 0.759 test_loss:  0.511 test_accuracy: 0.772
epoch:  15 loss:  0.51 accuracy: 0.759 test_loss:  0.498 test_accuracy: 0.772
epoch:  16 loss:  0.5 accuracy: 0.759 test_loss:  0.489 test_accuracy: 0.772
epoch:  17 loss:  0.493 accuracy: 0.759 test_loss:  0.482 test_accuracy: 0.772
epoch:  18 loss:  0.482 accuracy: 0.759 test_loss:  0.471 test_accuracy: 0.771
epoch:  19 loss:  0.472 accuracy: 0.758 test_loss:  0.461 test_accuracy: 0.771
epoch:  20 loss:  0.462 accuracy: 0.757 test_loss:  0.45 test_accuracy: 0.771
epoch:  21 loss:  0.452 accuracy: 0.758 test_loss:  0.44 test_accuracy: 0.771
epoch:  22 loss:  0.442 accuracy: 0.755 test_loss:  0.43 test_accuracy: 0.769
epoch:  23 loss:  0.433 accuracy: 0.749 test_loss:  0.421 test_accuracy: 0.763
epoch:  24 loss:  0.43 accuracy: 0.697 test_loss:  0.419 test_accuracy: 0.707
epoch:  25 loss:  0.419 accuracy: 0.7 test_loss:  0.407 test_accuracy: 0.71
epoch:  26 loss:  0.422 accuracy: 0.662 test_loss:  0.412 test_accuracy: 0.676
epoch:  27 loss:  0.405 accuracy: 0.743 test_loss:  0.392 test_accuracy: 0.755
epoch:  28 loss:  0.402 accuracy: 0.661 test_loss:  0.392 test_accuracy: 0.675
epoch:  29 loss:  0.406 accuracy: 0.751 test_loss:  0.392 test_accuracy: 0.762
epoch:  30 loss:  0.385 accuracy: 0.721 test_loss:  0.373 test_accuracy: 0.732
epoch:  31 loss:  0.38 accuracy: 0.72 test_loss:  0.368 test_accuracy: 0.731
epoch:  32 loss:  0.372 accuracy: 0.693 test_loss:  0.361 test_accuracy: 0.703
epoch:  33 loss:  0.367 accuracy: 0.691 test_loss:  0.357 test_accuracy: 0.7
epoch:  34 loss:  0.363 accuracy: 0.674 test_loss:  0.353 test_accuracy: 0.686
epoch:  35 loss:  0.366 accuracy: 0.64 test_loss:  0.358 test_accuracy: 0.653
epoch:  36 loss:  0.365 accuracy: 0.634 test_loss:  0.357 test_accuracy: 0.647
epoch:  37 loss:  0.352 accuracy: 0.659 test_loss:  0.344 test_accuracy: 0.673
epoch:  38 loss:  0.352 accuracy: 0.642 test_loss:  0.345 test_accuracy: 0.656
epoch:  39 loss:  0.346 accuracy: 0.664 test_loss:  0.337 test_accuracy: 0.676
epoch:  40 loss:  0.343 accuracy: 0.669 test_loss:  0.335 test_accuracy: 0.682
epoch:  41 loss:  0.342 accuracy: 0.673 test_loss:  0.333 test_accuracy: 0.685
epoch:  42 loss:  0.34 accuracy: 0.638 test_loss:  0.334 test_accuracy: 0.651
epoch:  43 loss:  0.34 accuracy: 0.633 test_loss:  0.334 test_accuracy: 0.648
epoch:  44 loss:  0.333 accuracy: 0.653 test_loss:  0.327 test_accuracy: 0.666
epoch:  45 loss:  0.332 accuracy: 0.645 test_loss:  0.326 test_accuracy: 0.658
epoch:  46 loss:  0.332 accuracy: 0.635 test_loss:  0.327 test_accuracy: 0.649
epoch:  47 loss:  0.329 accuracy: 0.657 test_loss:  0.323 test_accuracy: 0.67
epoch:  48 loss:  0.327 accuracy: 0.657 test_loss:  0.321 test_accuracy: 0.67
epoch:  49 loss:  0.328 accuracy: 0.668 test_loss:  0.321 test_accuracy: 0.681
epoch:  50 loss:  0.325 accuracy: 0.635 test_loss:  0.32 test_accuracy: 0.649
epoch:  51 loss:  0.326 accuracy: 0.625 test_loss:  0.323 test_accuracy: 0.64
epoch:  52 loss:  0.322 accuracy: 0.655 test_loss:  0.316 test_accuracy: 0.669
epoch:  53 loss:  0.321 accuracy: 0.638 test_loss:  0.315 test_accuracy: 0.649
epoch:  54 loss:  0.318 accuracy: 0.651 test_loss:  0.313 test_accuracy: 0.664
epoch:  55 loss:  0.318 accuracy: 0.653 test_loss:  0.312 test_accuracy: 0.668
epoch:  56 loss:  0.314 accuracy: 0.643 test_loss:  0.31 test_accuracy: 0.658
epoch:  57 loss:  0.322 accuracy: 0.617 test_loss:  0.32 test_accuracy: 0.631
epoch:  58 loss:  0.313 accuracy: 0.642 test_loss:  0.308 test_accuracy: 0.655
epoch:  59 loss:  0.311 accuracy: 0.632 test_loss:  0.308 test_accuracy: 0.647
epoch:  60 loss:  0.31 accuracy: 0.631 test_loss:  0.308 test_accuracy: 0.645
epoch:  61 loss:  0.309 accuracy: 0.635 test_loss:  0.306 test_accuracy: 0.652
epoch:  62 loss:  0.31 accuracy: 0.625 test_loss:  0.308 test_accuracy: 0.64
epoch:  63 loss:  0.307 accuracy: 0.645 test_loss:  0.304 test_accuracy: 0.658
epoch:  64 loss:  0.306 accuracy: 0.635 test_loss:  0.303 test_accuracy: 0.651
epoch:  65 loss:  0.307 accuracy: 0.627 test_loss:  0.305 test_accuracy: 0.643
epoch:  66 loss:  0.304 accuracy: 0.631 test_loss:  0.302 test_accuracy: 0.644
epoch:  67 loss:  0.307 accuracy: 0.651 test_loss:  0.304 test_accuracy: 0.664
epoch:  68 loss:  0.301 accuracy: 0.639 test_loss:  0.299 test_accuracy: 0.653
epoch:  69 loss:  0.301 accuracy: 0.629 test_loss:  0.3 test_accuracy: 0.645
epoch:  70 loss:  0.3 accuracy: 0.64 test_loss:  0.298 test_accuracy: 0.657
epoch:  71 loss:  0.301 accuracy: 0.624 test_loss:  0.3 test_accuracy: 0.639
epoch:  72 loss:  0.297 accuracy: 0.631 test_loss:  0.296 test_accuracy: 0.646
epoch:  73 loss:  0.297 accuracy: 0.636 test_loss:  0.295 test_accuracy: 0.651
epoch:  74 loss:  0.299 accuracy: 0.623 test_loss:  0.299 test_accuracy: 0.638
epoch:  75 loss:  0.296 accuracy: 0.637 test_loss:  0.296 test_accuracy: 0.652
epoch:  76 loss:  0.296 accuracy: 0.646 test_loss:  0.294 test_accuracy: 0.661
epoch:  77 loss:  0.294 accuracy: 0.635 test_loss:  0.292 test_accuracy: 0.648
epoch:  78 loss:  0.294 accuracy: 0.64 test_loss:  0.293 test_accuracy: 0.653
epoch:  79 loss:  0.299 accuracy: 0.616 test_loss:  0.301 test_accuracy: 0.631
epoch:  80 loss:  0.295 accuracy: 0.622 test_loss:  0.296 test_accuracy: 0.637
epoch:  81 loss:  0.289 accuracy: 0.634 test_loss:  0.289 test_accuracy: 0.649
epoch:  82 loss:  0.292 accuracy: 0.624 test_loss:  0.292 test_accuracy: 0.639
epoch:  83 loss:  0.287 accuracy: 0.635 test_loss:  0.287 test_accuracy: 0.65
epoch:  84 loss:  0.286 accuracy: 0.63 test_loss:  0.286 test_accuracy: 0.643
epoch:  85 loss:  0.287 accuracy: 0.627 test_loss:  0.288 test_accuracy: 0.64
epoch:  86 loss:  0.291 accuracy: 0.653 test_loss:  0.29 test_accuracy: 0.667
epoch:  87 loss:  0.285 accuracy: 0.625 test_loss:  0.286 test_accuracy: 0.639
epoch:  88 loss:  0.284 accuracy: 0.631 test_loss:  0.284 test_accuracy: 0.647
epoch:  89 loss:  0.285 accuracy: 0.626 test_loss:  0.285 test_accuracy: 0.64
epoch:  90 loss:  0.292 accuracy: 0.655 test_loss:  0.291 test_accuracy: 0.67
epoch:  91 loss:  0.285 accuracy: 0.647 test_loss:  0.284 test_accuracy: 0.66
epoch:  92 loss:  0.288 accuracy: 0.617 test_loss:  0.29 test_accuracy: 0.631
epoch:  93 loss:  0.28 accuracy: 0.628 test_loss:  0.281 test_accuracy: 0.643
epoch:  94 loss:  0.279 accuracy: 0.634 test_loss:  0.279 test_accuracy: 0.647
epoch:  95 loss:  0.279 accuracy: 0.64 test_loss:  0.279 test_accuracy: 0.655
epoch:  96 loss:  0.277 accuracy: 0.633 test_loss:  0.277 test_accuracy: 0.647
epoch:  97 loss:  0.278 accuracy: 0.638 test_loss:  0.277 test_accuracy: 0.653
epoch:  98 loss:  0.282 accuracy: 0.621 test_loss:  0.284 test_accuracy: 0.636
epoch:  99 loss:  0.279 accuracy: 0.619 test_loss:  0.28 test_accuracy: 0.631
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch是一个广泛应用于深度学习的开源机器学习库,它提供了丰富的工具和接口,使得开发者可以更加便捷地构建和训练深度神经网络模型。 PyTorch极简入门可以按照以下几个步骤进行: 1. 安装PyTorch:首先需要在计算机中安装PyTorch库。可以通过官方网站或者使用包管理工具(如pip或conda)进行安装。安装完成后,可以在Python环境中导入PyTorch库。 2. 张量操作:PyTorch的核心是张量(Tensor),它是一个多维数组。学习如何创建、操作和使用张量是入门的关键。可以学习如何创建随机张量、更改张量形状、进行基本数学运算等。 3. 构建模型:在PyTorch中构建模型通常使用nn.Module类。可以学习如何定义自己的模型类,包括初始化函数、前向传播函数等。还可以学习如何添加层和激活函数,并了解常用的网络结构,如全连接层、卷积层等。 4. 训练模型:在PyTorch中训练模型通常需要定义损失函数和优化器。可以学习如何选择合适的损失函数,如交叉熵损失函数,以及常用的优化器,如随机梯度下降优化器。还可以学习如何使用训练数据批次来进行前向传播和反向传播,并进行参数更新。 5. 测试和评估:在训练完成后,需要对模型进行测试和评估。可以学习如何使用测试数据进行模型预测,并计算预测结果的准确率、精确率、召回率等指标。 以上步骤虽然只是简要概述,但可以帮助初学者了解PyTorch的基本概念和操作。通过实践和深入学习,可以逐渐掌握更多高级功能和技巧,从而更好地应用PyTorch进行深度学习研究和应用开发。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值