import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from sklearn.model_selection import train_test_split
data = pd.read_csv('dataset/HR.csv')
# Notebook-style inspection cells; run as a script, only data.info() prints anything.
data.head()
data.info()
data.salary.unique()
# One-hot encode the two object-dtype columns ('salary' and 'part') and drop the originals.
# dtype=float keeps the indicator columns numeric: pandas >= 2.0 defaults get_dummies to
# bool, and mixing bool with float/int columns turns `.values` into an object array that
# torch.from_numpy cannot convert.
data = data.join(pd.get_dummies(data.salary, dtype=float))
del data['salary']
data = data.join(pd.get_dummies(data.part, dtype=float))
del data['part']
data.head()

# Target: the 'left' column.
Y_data = data.left.values
print("Y_data.shape:\t", Y_data.shape)
Y = torch.from_numpy(Y_data).type(torch.FloatTensor)

# Features: every remaining column except 'left', in the frame's column order,
# forced to float32 so the array is numeric regardless of the dummy dtype.
X_data = data.drop(columns='left').to_numpy(dtype=np.float32)
X = torch.from_numpy(X_data).type(torch.FloatTensor)

# Training hyper-parameters shared by the cells below.
loss_fn = nn.BCELoss()
batch = 64
no_of_batches = len(data) // batch
epochs = 100
import torch.nn.functional as F


class Model(nn.Module):
    """Three-layer fully connected binary classifier with a sigmoid output.

    Args:
        in_features: Number of input features per sample. The default of 20 is
            the width this notebook was run with.
        hidden: Width of the two hidden layers.

    The forward pass maps an input of shape (batch, in_features) to
    probabilities of shape (batch, 1), each in (0, 1).
    """

    def __init__(self, in_features=20, hidden=64):
        super().__init__()
        # Attribute names (``liner_*``, sic) are kept unchanged so that
        # previously saved state_dict keys still load.
        self.liner_1 = nn.Linear(in_features, hidden)
        self.liner_2 = nn.Linear(hidden, hidden)
        self.liner_3 = nn.Linear(hidden, 1)

    def forward(self, input):
        x = F.relu(self.liner_1(input))
        x = F.relu(self.liner_2(x))
        # torch.sigmoid replaces the deprecated F.sigmoid (it emitted a UserWarning).
        return torch.sigmoid(self.liner_3(x))
# Adam learning rate, read by get_model() each time it is called.
lr = 1e-4


def get_model():
    """Create a fresh ``Model`` together with an Adam optimizer over its parameters.

    The module-level ``lr`` is looked up at call time, so reassigning it
    affects subsequent calls.

    Returns:
        tuple: ``(model, optimizer)``.
    """
    net = Model()
    optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)
    return net, optimizer
# Random split with train_test_split's default proportions
# (the run log shows 11249 train / 3750 test rows).
train_x, test_x, train_y, test_y = train_test_split(X_data, Y_data)
train_x = torch.from_numpy(train_x).type(torch.float32)
test_x = torch.from_numpy(test_x).type(torch.float32)
# Targets become column vectors of shape (N, 1) to match the model's (N, 1)
# output. With 1-D targets, BCELoss warns about (or, on newer torch, rejects)
# the size mismatch, and elementwise comparisons against predictions broadcast
# to an (N, N) matrix.
train_y = torch.from_numpy(train_y).type(torch.float32).unsqueeze(1)
test_y = torch.from_numpy(test_y).type(torch.float32).unsqueeze(1)

train_ds = TensorDataset(train_x, train_y)
train_dl = DataLoader(train_ds, batch_size=batch, shuffle=True)
test_ds = TensorDataset(test_x, test_y)
test_dl = DataLoader(test_ds, batch_size=batch)
def accuracy(y_pred, y_true):
    """Return the fraction of thresholded predictions that equal the labels.

    Args:
        y_pred: Predicted probabilities. Any shape, e.g. (N, 1) model output.
        y_true: 0/1 labels with the same number of elements, e.g. (N,) or (N, 1).

    Returns:
        A 0-d float32 tensor with the accuracy in [0, 1].

    Raises:
        ValueError: if the two inputs hold a different number of elements.
    """
    # Flatten both sides first. Comparing an (N, 1) tensor with an (N,) tensor
    # would broadcast to (N, N) and average over every prediction/label pair.
    pred_labels = (y_pred > 0.5).type(torch.float32).reshape(-1)
    targets = y_true.reshape(-1)
    if pred_labels.numel() != targets.numel():
        raise ValueError(
            f"y_pred has {pred_labels.numel()} elements but y_true has {targets.numel()}"
        )
    return (pred_labels == targets).float().mean()
model, optim = get_model()
for epoch in range(epochs):
    # One pass over the shuffled mini-batches.
    for x, y in train_dl:
        y_pred = model(x)
        # Align the target with the prediction's shape so BCELoss compares
        # element-wise, whether y arrives as (B,) or (B, 1).
        loss = loss_fn(y_pred, y.reshape_as(y_pred))
        optim.zero_grad()
        loss.backward()
        optim.step()

    # Full-split metrics. Each split goes through the model once, and that
    # output is reused for both loss and accuracy.
    with torch.no_grad():
        train_pred = model(train_x)
        train_target = train_y.reshape_as(train_pred)
        epoch_accuracy = accuracy(train_pred, train_target)
        epoch_loss = loss_fn(train_pred, train_target)

        test_pred = model(test_x)
        test_target = test_y.reshape_as(test_pred)
        epoch_test_accuracy = accuracy(test_pred, test_target)
        epoch_test_loss = loss_fn(test_pred, test_target)

    print('epoch: ', epoch, 'loss: ', round(epoch_loss.item(), 3),
          'accuracy:', round(epoch_accuracy.item(), 3),
          'test_loss: ', round(epoch_test_loss.item(), 3),
          'test_accuracy:', round(epoch_test_accuracy.item(), 3)
          )
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14999 entries, 0 to 14998
Data columns (total 10 columns):
--- ------ -------------- -----
0 satisfaction_level 14999 non-null float64
1 last_evaluation 14999 non-null float64
2 number_project 14999 non-null int64
3 average_montly_hours 14999 non-null int64
4 time_spend_company 14999 non-null int64
5 Work_accident 14999 non-null int64
6 left 14999 non-null int64
7 promotion_last_5years 14999 non-null int64
8 part 14999 non-null object
9 salary 14999 non-null object
dtypes: float64(2), int64(6), object(2)
memory usage: 1.1+ MB
Y_data.shape: (14999,)
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\functional.py:1350: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\modules\loss.py:498: UserWarning: Using a target size (torch.Size([64])) that is different to the input size (torch.Size([64, 1])) is deprecated. Please ensure they have the same size.
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\modules\loss.py:498: UserWarning: Using a target size (torch.Size([49])) that is different to the input size (torch.Size([49, 1])) is deprecated. Please ensure they have the same size.
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\modules\loss.py:498: UserWarning: Using a target size (torch.Size([11249])) that is different to the input size (torch.Size([11249, 1])) is deprecated. Please ensure they have the same size.
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\modules\loss.py:498: UserWarning: Using a target size (torch.Size([3750])) that is different to the input size (torch.Size([3750, 1])) is deprecated. Please ensure they have the same size.
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
epoch: 0 loss: 0.566 accuracy: 0.759 test_loss: 0.558 test_accuracy: 0.772
epoch: 1 loss: 0.566 accuracy: 0.759 test_loss: 0.558 test_accuracy: 0.772
epoch: 2 loss: 0.566 accuracy: 0.759 test_loss: 0.559 test_accuracy: 0.772
epoch: 3 loss: 0.563 accuracy: 0.759 test_loss: 0.554 test_accuracy: 0.772
epoch: 4 loss: 0.563 accuracy: 0.759 test_loss: 0.553 test_accuracy: 0.772
epoch: 5 loss: 0.561 accuracy: 0.759 test_loss: 0.553 test_accuracy: 0.772
epoch: 6 loss: 0.557 accuracy: 0.759 test_loss: 0.547 test_accuracy: 0.772
epoch: 7 loss: 0.554 accuracy: 0.759 test_loss: 0.543 test_accuracy: 0.772
epoch: 8 loss: 0.556 accuracy: 0.759 test_loss: 0.544 test_accuracy: 0.772
epoch: 9 loss: 0.546 accuracy: 0.759 test_loss: 0.537 test_accuracy: 0.772
epoch: 10 loss: 0.554 accuracy: 0.759 test_loss: 0.546 test_accuracy: 0.772
epoch: 11 loss: 0.536 accuracy: 0.759 test_loss: 0.525 test_accuracy: 0.772
epoch: 12 loss: 0.53 accuracy: 0.759 test_loss: 0.52 test_accuracy: 0.772
epoch: 13 loss: 0.525 accuracy: 0.759 test_loss: 0.513 test_accuracy: 0.772
epoch: 14 loss: 0.521 accuracy: 0.759 test_loss: 0.511 test_accuracy: 0.772
epoch: 15 loss: 0.51 accuracy: 0.759 test_loss: 0.498 test_accuracy: 0.772
epoch: 16 loss: 0.5 accuracy: 0.759 test_loss: 0.489 test_accuracy: 0.772
epoch: 17 loss: 0.493 accuracy: 0.759 test_loss: 0.482 test_accuracy: 0.772
epoch: 18 loss: 0.482 accuracy: 0.759 test_loss: 0.471 test_accuracy: 0.771
epoch: 19 loss: 0.472 accuracy: 0.758 test_loss: 0.461 test_accuracy: 0.771
epoch: 20 loss: 0.462 accuracy: 0.757 test_loss: 0.45 test_accuracy: 0.771
epoch: 21 loss: 0.452 accuracy: 0.758 test_loss: 0.44 test_accuracy: 0.771
epoch: 22 loss: 0.442 accuracy: 0.755 test_loss: 0.43 test_accuracy: 0.769
epoch: 23 loss: 0.433 accuracy: 0.749 test_loss: 0.421 test_accuracy: 0.763
epoch: 24 loss: 0.43 accuracy: 0.697 test_loss: 0.419 test_accuracy: 0.707
epoch: 25 loss: 0.419 accuracy: 0.7 test_loss: 0.407 test_accuracy: 0.71
epoch: 26 loss: 0.422 accuracy: 0.662 test_loss: 0.412 test_accuracy: 0.676
epoch: 27 loss: 0.405 accuracy: 0.743 test_loss: 0.392 test_accuracy: 0.755
epoch: 28 loss: 0.402 accuracy: 0.661 test_loss: 0.392 test_accuracy: 0.675
epoch: 29 loss: 0.406 accuracy: 0.751 test_loss: 0.392 test_accuracy: 0.762
epoch: 30 loss: 0.385 accuracy: 0.721 test_loss: 0.373 test_accuracy: 0.732
epoch: 31 loss: 0.38 accuracy: 0.72 test_loss: 0.368 test_accuracy: 0.731
epoch: 32 loss: 0.372 accuracy: 0.693 test_loss: 0.361 test_accuracy: 0.703
epoch: 33 loss: 0.367 accuracy: 0.691 test_loss: 0.357 test_accuracy: 0.7
epoch: 34 loss: 0.363 accuracy: 0.674 test_loss: 0.353 test_accuracy: 0.686
epoch: 35 loss: 0.366 accuracy: 0.64 test_loss: 0.358 test_accuracy: 0.653
epoch: 36 loss: 0.365 accuracy: 0.634 test_loss: 0.357 test_accuracy: 0.647
epoch: 37 loss: 0.352 accuracy: 0.659 test_loss: 0.344 test_accuracy: 0.673
epoch: 38 loss: 0.352 accuracy: 0.642 test_loss: 0.345 test_accuracy: 0.656
epoch: 39 loss: 0.346 accuracy: 0.664 test_loss: 0.337 test_accuracy: 0.676
epoch: 40 loss: 0.343 accuracy: 0.669 test_loss: 0.335 test_accuracy: 0.682
epoch: 41 loss: 0.342 accuracy: 0.673 test_loss: 0.333 test_accuracy: 0.685
epoch: 42 loss: 0.34 accuracy: 0.638 test_loss: 0.334 test_accuracy: 0.651
epoch: 43 loss: 0.34 accuracy: 0.633 test_loss: 0.334 test_accuracy: 0.648
epoch: 44 loss: 0.333 accuracy: 0.653 test_loss: 0.327 test_accuracy: 0.666
epoch: 45 loss: 0.332 accuracy: 0.645 test_loss: 0.326 test_accuracy: 0.658
epoch: 46 loss: 0.332 accuracy: 0.635 test_loss: 0.327 test_accuracy: 0.649
epoch: 47 loss: 0.329 accuracy: 0.657 test_loss: 0.323 test_accuracy: 0.67
epoch: 48 loss: 0.327 accuracy: 0.657 test_loss: 0.321 test_accuracy: 0.67
epoch: 49 loss: 0.328 accuracy: 0.668 test_loss: 0.321 test_accuracy: 0.681
epoch: 50 loss: 0.325 accuracy: 0.635 test_loss: 0.32 test_accuracy: 0.649
epoch: 51 loss: 0.326 accuracy: 0.625 test_loss: 0.323 test_accuracy: 0.64
epoch: 52 loss: 0.322 accuracy: 0.655 test_loss: 0.316 test_accuracy: 0.669
epoch: 53 loss: 0.321 accuracy: 0.638 test_loss: 0.315 test_accuracy: 0.649
epoch: 54 loss: 0.318 accuracy: 0.651 test_loss: 0.313 test_accuracy: 0.664
epoch: 55 loss: 0.318 accuracy: 0.653 test_loss: 0.312 test_accuracy: 0.668
epoch: 56 loss: 0.314 accuracy: 0.643 test_loss: 0.31 test_accuracy: 0.658
epoch: 57 loss: 0.322 accuracy: 0.617 test_loss: 0.32 test_accuracy: 0.631
epoch: 58 loss: 0.313 accuracy: 0.642 test_loss: 0.308 test_accuracy: 0.655
epoch: 59 loss: 0.311 accuracy: 0.632 test_loss: 0.308 test_accuracy: 0.647
epoch: 60 loss: 0.31 accuracy: 0.631 test_loss: 0.308 test_accuracy: 0.645
epoch: 61 loss: 0.309 accuracy: 0.635 test_loss: 0.306 test_accuracy: 0.652
epoch: 62 loss: 0.31 accuracy: 0.625 test_loss: 0.308 test_accuracy: 0.64
epoch: 63 loss: 0.307 accuracy: 0.645 test_loss: 0.304 test_accuracy: 0.658
epoch: 64 loss: 0.306 accuracy: 0.635 test_loss: 0.303 test_accuracy: 0.651
epoch: 65 loss: 0.307 accuracy: 0.627 test_loss: 0.305 test_accuracy: 0.643
epoch: 66 loss: 0.304 accuracy: 0.631 test_loss: 0.302 test_accuracy: 0.644
epoch: 67 loss: 0.307 accuracy: 0.651 test_loss: 0.304 test_accuracy: 0.664
epoch: 68 loss: 0.301 accuracy: 0.639 test_loss: 0.299 test_accuracy: 0.653
epoch: 69 loss: 0.301 accuracy: 0.629 test_loss: 0.3 test_accuracy: 0.645
epoch: 70 loss: 0.3 accuracy: 0.64 test_loss: 0.298 test_accuracy: 0.657
epoch: 71 loss: 0.301 accuracy: 0.624 test_loss: 0.3 test_accuracy: 0.639
epoch: 72 loss: 0.297 accuracy: 0.631 test_loss: 0.296 test_accuracy: 0.646
epoch: 73 loss: 0.297 accuracy: 0.636 test_loss: 0.295 test_accuracy: 0.651
epoch: 74 loss: 0.299 accuracy: 0.623 test_loss: 0.299 test_accuracy: 0.638
epoch: 75 loss: 0.296 accuracy: 0.637 test_loss: 0.296 test_accuracy: 0.652
epoch: 76 loss: 0.296 accuracy: 0.646 test_loss: 0.294 test_accuracy: 0.661
epoch: 77 loss: 0.294 accuracy: 0.635 test_loss: 0.292 test_accuracy: 0.648
epoch: 78 loss: 0.294 accuracy: 0.64 test_loss: 0.293 test_accuracy: 0.653
epoch: 79 loss: 0.299 accuracy: 0.616 test_loss: 0.301 test_accuracy: 0.631
epoch: 80 loss: 0.295 accuracy: 0.622 test_loss: 0.296 test_accuracy: 0.637
epoch: 81 loss: 0.289 accuracy: 0.634 test_loss: 0.289 test_accuracy: 0.649
epoch: 82 loss: 0.292 accuracy: 0.624 test_loss: 0.292 test_accuracy: 0.639
epoch: 83 loss: 0.287 accuracy: 0.635 test_loss: 0.287 test_accuracy: 0.65
epoch: 84 loss: 0.286 accuracy: 0.63 test_loss: 0.286 test_accuracy: 0.643
epoch: 85 loss: 0.287 accuracy: 0.627 test_loss: 0.288 test_accuracy: 0.64
epoch: 86 loss: 0.291 accuracy: 0.653 test_loss: 0.29 test_accuracy: 0.667
epoch: 87 loss: 0.285 accuracy: 0.625 test_loss: 0.286 test_accuracy: 0.639
epoch: 88 loss: 0.284 accuracy: 0.631 test_loss: 0.284 test_accuracy: 0.647
epoch: 89 loss: 0.285 accuracy: 0.626 test_loss: 0.285 test_accuracy: 0.64
epoch: 90 loss: 0.292 accuracy: 0.655 test_loss: 0.291 test_accuracy: 0.67
epoch: 91 loss: 0.285 accuracy: 0.647 test_loss: 0.284 test_accuracy: 0.66
epoch: 92 loss: 0.288 accuracy: 0.617 test_loss: 0.29 test_accuracy: 0.631
epoch: 93 loss: 0.28 accuracy: 0.628 test_loss: 0.281 test_accuracy: 0.643
epoch: 94 loss: 0.279 accuracy: 0.634 test_loss: 0.279 test_accuracy: 0.647
epoch: 95 loss: 0.279 accuracy: 0.64 test_loss: 0.279 test_accuracy: 0.655
epoch: 96 loss: 0.277 accuracy: 0.633 test_loss: 0.277 test_accuracy: 0.647
epoch: 97 loss: 0.278 accuracy: 0.638 test_loss: 0.277 test_accuracy: 0.653
epoch: 98 loss: 0.282 accuracy: 0.621 test_loss: 0.284 test_accuracy: 0.636
epoch: 99 loss: 0.279 accuracy: 0.619 test_loss: 0.28 test_accuracy: 0.631