from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
# Load the HR dataset and one-hot encode the two categorical columns.
data = pd.read_csv("dataset/HR.csv")
print("data.head():\t", data.head())
# data.info() prints its summary directly and returns None, so call it
# once on its own — wrapping it in print() would just print "None".
data.info()
print("data.part.unique()", data.part.unique())
print("data.salary.unique()", data.salary.unique())
print("data.groupby(['salary', 'part']).size()", data.groupby(['salary', 'part']).size())
print("pd.get_dummies(data.salary):\t", pd.get_dummies(data.salary))
# One-hot encode 'salary' and 'part', then drop the original text columns
# so every remaining column is numeric.
data = data.join(pd.get_dummies(data.salary))
print("data.head()):\t", data.head())
data = data.join(pd.get_dummies(data.part))
del data["salary"]
del data["part"]
print("data.head():\t", data.head())
print("data.left.value_counts():\n", data.left.value_counts())
# Split the frame into a (N, 1) target tensor and a (N, 20) feature tensor.
Y_data = data.left.values.reshape(-1, 1)
print("Y_data.shape:\t", Y_data.shape)
Y = torch.from_numpy(Y_data).type(torch.float32)
print("Y.shape", Y.shape)
# Two equivalent ways to cast to float32:
#   on the numpy side:  .astype(np.float32)
#   on the torch side:  .type(torch.float32)
feature_cols = [c for c in data.columns if c != 'left']
X_data = data[feature_cols].values
X = torch.from_numpy(X_data.astype(np.float32))
print("X:\t", X)
print("X.size():\t", X.shape)
"""""""""""""""""""""""""""""""""""""""""""""""
创建模型:
from torch import nn
自定义模型:
nn.Module: 继承这个类
__init__:初始化所有的层
forward: 定义模型的运算过程 (前向传播的过程)
"""""""""""""""""""""""""""""""""""""""""""""""
"""
# Custom model, approach 1: instantiate the activation modules in __init__.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.liner_1 = nn.Linear(20, 64)
        self.liner_2 = nn.Linear(64, 64)
        self.liner_3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()        # initialize relu
        self.sigmoid = nn.Sigmoid()  # initialize sigmoid
    def forward(self, input):
        x = self.liner_1(input)
        x = self.relu(x)
        x = self.liner_2(x)
        x = self.relu(x)
        x = self.liner_3(x)
        x = self.sigmoid(x)
        return x
"""
"""""""""""""""""""""""""""""""""""
方法的改写: 方法二
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.liner_1 = nn.Linear(20, 64)
self.liner_2 = nn.Linear(64, 64)
self.liner_3 = nn.Linear(64, 1)
def forward(self, input):
x = F.relu(self.Liner_1(input))
x = F.relu(self.Liner_2(x))
x = F.sigmoid(self.Liner_3(x))
return x
"""""""""""""""""""""""""""""""""""
import torch.nn.functional as F


class Model(nn.Module):
    """Three-layer MLP for binary classification of the HR data.

    Input: (batch, 20) float32 features.
    Output: (batch, 1) probability in (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.Liner_1 = nn.Linear(20, 64)
        self.Liner_2 = nn.Linear(64, 64)
        self.Liner_3 = nn.Linear(64, 1)

    def forward(self, input):
        x = F.relu(self.Liner_1(input))
        x = F.relu(self.Liner_2(x))
        # F.sigmoid is deprecated and emits a UserWarning on every call
        # (visible in the original run log); torch.sigmoid is the
        # supported equivalent.
        x = torch.sigmoid(self.Liner_3(x))
        return x
"""
model = Model()
print("model:\t", model)
"""
lr = 0.001


def get_model(learning_rate=lr):
    """Build a fresh Model and an Adam optimizer over its parameters.

    learning_rate defaults to the module-level `lr` (0.001) so existing
    no-argument calls keep their behavior; passing a value lets callers
    reuse this factory with a different step size.

    Returns:
        (model, optimizer) tuple.
    """
    model = Model()
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
    return model, opt


model, optim = get_model()
# Binary cross-entropy — expects sigmoid probabilities, which Model emits.
loss_fn = nn.BCELoss()
batch = 64
no_of_batch = len(data) // batch  # number of full batches per epoch
epochs = 100
# Refactor with the TensorDataset class: it supports len() for the sample
# count and __getitem__ for indexing/slicing, e.g. len(HRdataset) and
# HRdataset[2:5].
HRdataset = TensorDataset(X, Y)
for epoch in range(epochs):
    for i in range(no_of_batch):
        start = i * batch
        x, y = HRdataset[start:start + batch]
        loss = loss_fn(model(x), y)
        optim.zero_grad()
        loss.backward()
        optim.step()
    # Report the full-dataset loss once per epoch, without tracking grads.
    with torch.no_grad():
        print("epoch:", epoch, "loss:", loss_fn(model(X), Y).data.item())
loss_fn(model(X), Y)
print("loss_fn(model(X), Y):\t", loss_fn(model(X), Y))
# Refactor with the DataLoader class: it batches and shuffles for us.
HR_ds = TensorDataset(X, Y)
HT_dl = DataLoader(HR_ds, batch_size=batch, shuffle=True)
# BUG FIX: this line was `mode1, optim = get_model()` (digit one, not the
# letter l). The loop below then kept forwarding the OLD `model` while
# `optim` stepped the fresh, never-used network's parameters — gradients
# accumulated on the old model were never zeroed or applied, so the loss
# bounced around without converging (see the second run log).
model, optim = get_model()
for epoch in range(epochs):
    for x, y in HT_dl:
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
    # Loss on the last batch of the epoch, outside the autograd graph.
    with torch.no_grad():
        print("epoch", epoch, "loss:", loss_fn(model(x), y).data.item())
data.head(): satisfaction_level last_evaluation ... part salary
0 0.38 0.53 ... sales low
1 0.80 0.86 ... sales medium
2 0.11 0.88 ... sales medium
3 0.72 0.87 ... sales low
4 0.37 0.52 ... sales low
[5 rows x 10 columns]
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14999 entries, 0 to 14998
Data columns (total 10 columns):
--- ------ -------------- -----
0 satisfaction_level 14999 non-null float64
1 last_evaluation 14999 non-null float64
2 number_project 14999 non-null int64
3 average_montly_hours 14999 non-null int64
4 time_spend_company 14999 non-null int64
5 Work_accident 14999 non-null int64
6 left 14999 non-null int64
7 promotion_last_5years 14999 non-null int64
8 part 14999 non-null object
9 salary 14999 non-null object
dtypes: float64(2), int64(6), object(2)
memory usage: 1.1+ MB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14999 entries, 0 to 14998
Data columns (total 10 columns):
--- ------ -------------- -----
0 satisfaction_level 14999 non-null float64
1 last_evaluation 14999 non-null float64
2 number_project 14999 non-null int64
3 average_montly_hours 14999 non-null int64
4 time_spend_company 14999 non-null int64
5 Work_accident 14999 non-null int64
6 left 14999 non-null int64
7 promotion_last_5years 14999 non-null int64
8 part 14999 non-null object
9 salary 14999 non-null object
dtypes: float64(2), int64(6), object(2)
memory usage: 1.1+ MB
data.info(): None
data.part.unique() ['sales' 'accounting' 'hr' 'technical' 'support' 'management' 'IT'
'product_mng' 'marketing' 'RandD']
data.salary.unique() ['low' 'medium' 'high']
data.groupby(['salary', 'part']).size() salary part
high IT 83
RandD 51
accounting 74
hr 45
management 225
marketing 80
product_mng 68
sales 269
support 141
technical 201
low IT 609
RandD 364
accounting 358
hr 335
management 180
marketing 402
product_mng 451
sales 2099
support 1146
technical 1372
medium IT 535
RandD 372
accounting 335
hr 359
management 225
marketing 376
product_mng 383
sales 1772
support 942
technical 1147
dtype: int64
pd.get_dummies(data.salary): high low medium
0 0 1 0
1 0 0 1
2 0 0 1
3 0 1 0
4 0 1 0
... ... ... ...
14994 0 1 0
14995 0 1 0
14996 0 1 0
14997 0 1 0
14998 0 1 0
[14999 rows x 3 columns]
data.head()): satisfaction_level last_evaluation number_project ... high low medium
0 0.38 0.53 2 ... 0 1 0
1 0.80 0.86 5 ... 0 0 1
2 0.11 0.88 7 ... 0 0 1
3 0.72 0.87 5 ... 0 1 0
4 0.37 0.52 2 ... 0 1 0
[5 rows x 13 columns]
data.head(): satisfaction_level last_evaluation ... support technical
0 0.38 0.53 ... 0 0
1 0.80 0.86 ... 0 0
2 0.11 0.88 ... 0 0
3 0.72 0.87 ... 0 0
4 0.37 0.52 ... 0 0
[5 rows x 21 columns]
data.left.value_counts():
0 11428
1 3571
Name: left, dtype: int64
Y_data.shape: (14999, 1)
Y.shape torch.Size([14999, 1])
X: tensor([[0.3800, 0.5300, 2.0000, ..., 1.0000, 0.0000, 0.0000],
[0.8000, 0.8600, 5.0000, ..., 1.0000, 0.0000, 0.0000],
[0.1100, 0.8800, 7.0000, ..., 1.0000, 0.0000, 0.0000],
...,
[0.3700, 0.5300, 2.0000, ..., 0.0000, 1.0000, 0.0000],
[0.1100, 0.9600, 6.0000, ..., 0.0000, 1.0000, 0.0000],
[0.3700, 0.5200, 2.0000, ..., 0.0000, 1.0000, 0.0000]])
X.size(): torch.Size([14999, 20])
E:\Professional Software\Anconda\envs\pytracking\lib\site-packages\torch\nn\functional.py:1350: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
epoch: 0 loss: 2.5677719116210938
epoch: 1 loss: 0.6458866596221924
epoch: 2 loss: 0.6800997257232666
epoch: 3 loss: 0.65093994140625
epoch: 4 loss: 0.62894207239151
epoch: 5 loss: 0.6119461059570312
epoch: 6 loss: 0.5984881520271301
epoch: 7 loss: 0.5879918932914734
epoch: 8 loss: 0.5798552632331848
epoch: 9 loss: 0.5732376575469971
epoch: 10 loss: 0.568317711353302
epoch: 11 loss: 0.564129650592804
epoch: 12 loss: 0.5608596801757812
epoch: 13 loss: 0.5583353638648987
epoch: 14 loss: 0.5562677383422852
epoch: 15 loss: 0.5548015236854553
epoch: 16 loss: 0.5534111857414246
epoch: 17 loss: 0.5524439215660095
epoch: 18 loss: 0.5516253113746643
epoch: 19 loss: 0.5509838461875916
epoch: 20 loss: 0.5504893660545349
epoch: 21 loss: 0.5501244068145752
epoch: 22 loss: 0.5497907400131226
epoch: 23 loss: 0.5495416522026062
epoch: 24 loss: 0.5494086146354675
epoch: 25 loss: 0.5491578578948975
epoch: 26 loss: 0.5490513443946838
epoch: 27 loss: 0.5490929484367371
epoch: 28 loss: 0.549022912979126
epoch: 29 loss: 0.5488981604576111
epoch: 30 loss: 0.5489230751991272
epoch: 31 loss: 0.5488610863685608
epoch: 32 loss: 0.5489944219589233
epoch: 33 loss: 0.5487973690032959
epoch: 34 loss: 0.548848569393158
epoch: 35 loss: 0.5488126277923584
epoch: 36 loss: 0.5488066673278809
epoch: 37 loss: 0.5489434599876404
epoch: 38 loss: 0.5487702488899231
epoch: 39 loss: 0.5487971305847168
epoch: 40 loss: 0.5488684177398682
epoch: 41 loss: 0.5488698482513428
epoch: 42 loss: 0.548896074295044
epoch: 43 loss: 0.5488118529319763
epoch: 44 loss: 0.5488182902336121
epoch: 45 loss: 0.5489262342453003
epoch: 46 loss: 0.5488597750663757
epoch: 47 loss: 0.5490107536315918
epoch: 48 loss: 0.5489205121994019
epoch: 49 loss: 0.5489976406097412
epoch: 50 loss: 0.5488433241844177
epoch: 51 loss: 0.5490217208862305
epoch: 52 loss: 0.5488499999046326
epoch: 53 loss: 0.5488947033882141
epoch: 54 loss: 0.5489994287490845
epoch: 55 loss: 0.549033522605896
epoch: 56 loss: 0.5488603115081787
epoch: 57 loss: 0.5488812327384949
epoch: 58 loss: 0.5489685535430908
epoch: 59 loss: 0.5490075349807739
epoch: 60 loss: 0.5490177869796753
epoch: 61 loss: 0.5490460395812988
epoch: 62 loss: 0.5488992929458618
epoch: 63 loss: 0.54892897605896
epoch: 64 loss: 0.548870325088501
epoch: 65 loss: 0.5489423871040344
epoch: 66 loss: 0.5489640235900879
epoch: 67 loss: 0.5489773154258728
epoch: 68 loss: 0.5489982962608337
epoch: 69 loss: 0.5490091443061829
epoch: 70 loss: 0.5489961504936218
epoch: 71 loss: 0.5490107536315918
epoch: 72 loss: 0.549019455909729
epoch: 73 loss: 0.5490281581878662
epoch: 74 loss: 0.5490407943725586
epoch: 75 loss: 0.549043595790863
epoch: 76 loss: 0.5490519404411316
epoch: 77 loss: 0.5488772392272949
epoch: 78 loss: 0.548883855342865
epoch: 79 loss: 0.5488863587379456
epoch: 80 loss: 0.5488969087600708
epoch: 81 loss: 0.5488990545272827
epoch: 82 loss: 0.5489053726196289
epoch: 83 loss: 0.5489070415496826
epoch: 84 loss: 0.5489094853401184
epoch: 85 loss: 0.5489101409912109
epoch: 86 loss: 0.5489121079444885
epoch: 87 loss: 0.5489179491996765
epoch: 88 loss: 0.5489185452461243
epoch: 89 loss: 0.5489200949668884
epoch: 90 loss: 0.5489203333854675
epoch: 91 loss: 0.548920750617981
epoch: 92 loss: 0.5489305853843689
epoch: 93 loss: 0.5489307641983032
epoch: 94 loss: 0.5489311218261719
epoch: 95 loss: 0.5489312410354614
epoch: 96 loss: 0.5489327311515808
epoch: 97 loss: 0.5488605499267578
epoch: 98 loss: 0.5488608479499817
epoch: 99 loss: 0.5488609671592712
loss_fn(model(X), Y): tensor(0.5489, grad_fn=<BinaryCrossEntropyBackward>)
epoch 0 loss: 0.7845770716667175
epoch 1 loss: 0.7845770716667175
epoch 2 loss: 0.6283407807350159
epoch 3 loss: 0.47210463881492615
epoch 4 loss: 0.524183452129364
epoch 5 loss: 0.47210463881492615
epoch 6 loss: 0.6283408999443054
epoch 7 loss: 0.6804195642471313
epoch 8 loss: 0.4200258255004883
epoch 9 loss: 0.8366557955741882
epoch 10 loss: 0.5762620568275452
epoch 11 loss: 0.42002594470977783
epoch 12 loss: 0.36794713139533997
epoch 13 loss: 0.5241833925247192
epoch 14 loss: 0.47210463881492615
epoch 15 loss: 0.5762621760368347
epoch 16 loss: 0.6804195642471313
epoch 17 loss: 0.524183452129364
epoch 18 loss: 0.5241833329200745
epoch 19 loss: 0.6283408999443054
epoch 20 loss: 0.5241833925247192
epoch 21 loss: 0.8366557955741882
epoch 22 loss: 0.5241833925247192
epoch 23 loss: 0.5762621164321899
epoch 24 loss: 0.6283408999443054
epoch 25 loss: 0.7845770716667175
epoch 26 loss: 0.6283407807350159
epoch 27 loss: 0.5241833925247192
epoch 28 loss: 0.524183452129364
epoch 29 loss: 0.7324983477592468
epoch 30 loss: 0.42002591490745544
epoch 31 loss: 0.524183452129364
epoch 32 loss: 0.7845770716667175
epoch 33 loss: 0.5762621164321899
epoch 34 loss: 0.6804195642471313
epoch 35 loss: 0.7324983477592468
epoch 36 loss: 0.5762621760368347
epoch 37 loss: 0.6283408999443054
epoch 38 loss: 0.4721047878265381
epoch 39 loss: 0.36794713139533997
epoch 40 loss: 0.524183452129364
epoch 41 loss: 0.5762622356414795
epoch 42 loss: 0.524183452129364
epoch 43 loss: 0.47210463881492615
epoch 44 loss: 0.524183452129364
epoch 45 loss: 0.7324981689453125
epoch 46 loss: 0.47210460901260376
epoch 47 loss: 0.6804197430610657
epoch 48 loss: 0.6804195642471313
epoch 49 loss: 0.47210463881492615
epoch 50 loss: 0.5241833329200745
epoch 51 loss: 0.5762621760368347
epoch 52 loss: 0.42002591490745544
epoch 53 loss: 0.7845770716667175
epoch 54 loss: 0.7324983477592468
epoch 55 loss: 0.42002594470977783
epoch 56 loss: 0.36794713139533997
epoch 57 loss: 0.47210460901260376
epoch 58 loss: 0.5241833329200745
epoch 59 loss: 0.5762621164321899
epoch 60 loss: 0.5762621760368347
epoch 61 loss: 0.6804195642471313
epoch 62 loss: 0.6283408999443054
epoch 63 loss: 0.4721046984195709
epoch 64 loss: 0.5241833925247192
epoch 65 loss: 0.4721047878265381
epoch 66 loss: 0.5762621164321899
epoch 67 loss: 0.36794713139533997
epoch 68 loss: 0.42002585530281067
epoch 69 loss: 0.36794716119766235
epoch 70 loss: 0.4721046984195709
epoch 71 loss: 0.6283408999443054
epoch 72 loss: 0.5241833329200745
epoch 73 loss: 0.47210460901260376
epoch 74 loss: 0.42002585530281067
epoch 75 loss: 0.6804195642471313
epoch 76 loss: 0.6283408999443054
epoch 77 loss: 0.42002585530281067
epoch 78 loss: 0.6283408999443054
epoch 79 loss: 0.6283408999443054
epoch 80 loss: 0.5241833925247192
epoch 81 loss: 0.5762621164321899
epoch 82 loss: 0.6283408999443054
epoch 83 loss: 0.5762621164321899
epoch 84 loss: 0.42002585530281067
epoch 85 loss: 0.5762622952461243
epoch 86 loss: 0.5241833925247192
epoch 87 loss: 0.5241833925247192
epoch 88 loss: 0.6804196238517761
epoch 89 loss: 0.5241833925247192
epoch 90 loss: 0.5762621760368347
epoch 91 loss: 0.5762621760368347
epoch 92 loss: 0.4721046984195709
epoch 93 loss: 0.47210463881492615
epoch 94 loss: 0.5762622356414795
epoch 95 loss: 0.4721047878265381
epoch 96 loss: 0.5762621760368347
epoch 97 loss: 0.4721046984195709
epoch 98 loss: 0.31586840748786926
epoch 99 loss: 0.42002594470977783