PyTorch 入门学习：处理多维特征输入
处理多维特征输入
import torch
import numpy as np
import torchvision
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
class Model(torch.nn.Module):
    """Three-layer fully connected binary classifier: 8 -> 6 -> 4 -> 1.

    Hidden layers use ReLU activations; the output layer uses a sigmoid so
    every output lies in [0, 1] and can be passed straight to
    ``torch.nn.BCELoss``.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        # ReLU is used for the hidden layers; torch.nn.Sigmoid() would also
        # work here but tends to train more slowly.
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map a batch of 8-feature rows to probabilities.

        Args:
            x: float tensor whose last dimension is 8 (e.g. shape (N, 8)).

        Returns:
            Tensor with the same leading dimensions and a last dimension of 1,
            with values in [0, 1].
        """
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        # Sigmoid only on the final layer so the output is a probability.
        return self.sigmoid(self.linear3(x))
# Build the network, loss and optimizer, then train full-batch on the
# diabetes dataset, printing the loss after every epoch.
model = Model()
# reduction='mean' averages the loss over the batch; it is the replacement
# for the deprecated size_average=True argument and behaves the same way.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# np.loadtxt decompresses files with a .gz extension transparently.
xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
# Every column except the last one is an input feature.
x_data = torch.from_numpy(xy[:, :-1])
# Indexing with the list [-1] keeps a 2-D (N, 1) target that matches the
# model's (N, 1) output. NOTE(review): BCELoss needs targets in [0, 1];
# presumably the last column holds 0/1 labels — confirm against the data.
y_data = torch.from_numpy(xy[:, [-1]])

for epoch in range(1000):
    # Forward pass: predict on the whole dataset and compute the loss.
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    # Backward pass: clear gradients left over from the previous step
    # before accumulating new ones.
    optimizer.zero_grad()
    loss.backward()

    # Update the parameters with the freshly computed gradients.
    optimizer.step()
#out (log shows epochs 0–99 only; the loop above runs 1000 epochs):
0 0.7524145245552063
1 0.7462137937545776
2 0.7403224110603333
3 0.734732985496521
4 0.7294375896453857
5 0.7244138717651367
6 0.7196581363677979
7 0.7151512503623962
8 0.7108716368675232
9 0.7068101763725281
10 0.7029609680175781
11 0.6993327736854553
12 0.6958991885185242
13 0.6926288604736328
14 0.6895169019699097
15 0.6865927577018738
16 0.6838260889053345
17 0.6811893582344055
18 0.67868971824646
19 0.6763197183609009
20 0.6740730404853821
21 0.671952486038208
22 0.6699475049972534
23 0.6680536866188049
24 0.6662566661834717
25 0.6645610332489014
26 0.6629753112792969
27 0.6615006327629089
28 0.6601430177688599
29 0.6588588953018188
30 0.657647967338562
31 0.6565213799476624
32 0.6554760336875916
33 0.654509425163269
34 0.653608500957489
35 0.652764081954956
36 0.6519767642021179
37 0.6512464284896851
38 0.6505675911903381
39 0.649936318397522
40 0.6493468880653381
41 0.648797869682312
42 0.6482827663421631
43 0.6478033065795898
44 0.6473581790924072
45 0.646946132183075
46 0.6465625762939453
47 0.6462015509605408
48 0.645862340927124
49 0.6455459594726562
50 0.6452481150627136
51 0.6449684500694275
52 0.6447051763534546
53 0.6444559693336487
54 0.6442214250564575
55 0.6439983248710632
56 0.6437851786613464
57 0.6435820460319519
58 0.6433899402618408
59 0.6432061791419983
60 0.643031120300293
61 0.6428627967834473
62 0.6426998972892761
63 0.6425420641899109
64 0.6423884034156799
65 0.6422398686408997
66 0.6420952677726746
67 0.6419550776481628
68 0.6418185234069824
69 0.6416840553283691
70 0.641552209854126
71 0.6414230465888977
72 0.6412957310676575
73 0.6411715149879456
74 0.6410495638847351
75 0.6409288644790649
76 0.640809178352356
77 0.6406914591789246
78 0.6405748724937439
79 0.6404593586921692
80 0.6403442025184631
81 0.6402301788330078
82 0.6401163935661316
83 0.6400025486946106
84 0.639888346195221
85 0.6397725939750671
86 0.6396569013595581
87 0.6395405530929565
88 0.6394234299659729
89 0.6393055319786072
90 0.639186680316925
91 0.6390676498413086
92 0.6389491558074951
93 0.6388294696807861
94 0.6387082934379578
95 0.6385860443115234
96 0.6384627819061279
97 0.638337254524231
98 0.638209342956543
99 0.6380797624588013
Process finished with exit code 0