前言
b站刘洪普老师的pytorch入门课笔记。记录学习。
本文内容为多维数据的计算过程示例。
计算过程
首先,
所以,
其中,
故,
Mini-Batch
其中,
所以,
代码
我的数据是在
F:\Anaconda3_5.3.1\Lib\site-packages\sklearn\datasets\data\diabetes_data.csv.gz
其中gz(gzip)文件是linux下的压缩格式,可以采用常用的7z解压
注1:神经网络本质上是寻找一种非线性空间变换函数。通过在多次的线性变换中添加激活函数(非线性因子)来实现这个过程。
注2:多层的网络有一点缺陷,即学习能力太强会将噪声也学到,所以也不能设置太多层,要考虑泛化能力。
import torch
import numpy as np
#数据读入
x = np.loadtxt('F:\Anaconda3_5.3.1\Lib\site-packages\sklearn\datasets\data\diabetes_data.csv.gz',delimiter=' ',dtype=np.float32) #delimiter 指的是分隔符的意思
y = np.loadtxt('F:\Anaconda3_5.3.1\Lib\site-packages\sklearn\datasets\data\diabetes_target.csv.gz',delimiter=' ',dtype=np.float32)
x_data = torch.from_numpy(x[:,:]) #x_data,[442*10]
y_data = torch.from_numpy(y[:])# y_data,[442]
y_data = y_data.reshape((442,1))# y_data,[442*1]
class Model(torch.nn.Module):
def __init__(self):
super(Model,self).__init__()
self.linear = torch.nn.Linear(10,1) #直接10维输入,1维输出的二分类模型 对应的w为10*1的向量。
#self.linear1 = torch.nn.Linear(10,8)
#self.linear2 = torch.nn.Linear(8,4)
#self.linear3 = torch.nn.Linear(4,1) #从10->8->4->1维
self.sigmoid = torch.nn.Sigmoid() #这里的激活函数可以选择很多,也可以用Sign()、Linear()、ReLU()等。 要改成:self.activate = torch.nn.ReLU().
#self.activate = torch.nn.ReLU()#改成activate,会导致损失保持一个值不变
def forward(self, x):
x = self.sigmoid(self.linear(x)) #直接输入10->1
#x = self.activate(self.linear1(x))
#x = self.activate(self.linear2(x))
#x = self.activate(self.linear3(x))
return x
model = Model()
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
for epoch in range(100):
#Forward
y_pred = model(x_data) #
loss = criterion(y_pred,y_data)
print(epoch,loss.item())
#Backward
optimizer.zero_grad()
loss.backward()
#update
optimizer.step()
输出
0 46.6182746887207
1 -183.4557342529297
2 -412.6076354980469
3 -641.358154296875
4 -869.9998779296875
5 -1098.6163330078125
6 -1327.2020263671875
7 -1555.8177490234375
8 -1784.475341796875
9 -2013.0660400390625
10 -2243.395263671875
11 -2409.427490234375
12 -4175.97216796875
13 -4175.97216796875
14 -4175.97216796875
15 -4175.97216796875
16 -4175.97216796875
17 -4175.97216796875
18 -4175.97216796875
19 -4175.97216796875
20 -4175.97216796875
21 -4175.97216796875
22 -4175.97216796875
23 -4175.97216796875
24 -4175.97216796875
25 -4175.97216796875
26 -4175.97216796875
27 -4175.97216796875
28 -4175.97216796875
29 -4175.97216796875
30 -4175.97216796875
31 -4175.97216796875
32 -4175.97216796875
33 -4175.97216796875
34 -4175.97216796875
35 -4175.97216796875
36 -4175.97216796875
37 -4175.97216796875
38 -4175.97216796875
39 -4175.97216796875
40 -4175.97216796875
41 -4175.97216796875
42 -4175.97216796875
43 -4175.97216796875
44 -4175.97216796875
45 -4175.97216796875
46 -4175.97216796875
47 -4175.97216796875
48 -4175.97216796875
49 -4175.97216796875
总结
不太能理解损失的结果。。,可能是计算过程中有哪步有问题?整体的计算过程就是这样了,仅供参考~