1. Gradient descent implemented with NumPy
import numpy as np
x_data = np.array([1, 2, 3])   # training inputs
y_data = np.array([2, 8, 6])   # training targets
lr = 0.1    # learning rate
w = 0       # initial weight
cost = []   # loss recorded at each iteration
for i in range(10):
    y_predict = x_data * w
    loss = np.average((y_predict - y_data)**2)                   # mean squared error
    cost.append(loss)
    d_w = 2 * (y_predict - y_data) @ x_data / x_data.shape[0]    # gradient of the MSE with respect to w
    w = w - lr * d_w                                             # gradient descent update
print(w)
print(cost)
2.571428571424112
[34.666666666666664, 3.946666666666666, 3.810133333333333, 3.8095265185185188, 3.809523821563785, 3.8095238095773207, 3.809523809524048, 3.80952380952381, 3.809523809523809, 3.80952380952381]
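For this one-weight least-squares problem the optimum is also available in closed form, which gives a quick sanity check on the value the loop converges to. The snippet below is a small sketch of that check (same data as above):

import numpy as np
x_data = np.array([1, 2, 3])
y_data = np.array([2, 8, 6])
# Minimizing the MSE of y ≈ w * x over w gives w* = (x · y) / (x · x).
w_star = (x_data @ y_data) / (x_data @ x_data)           # 36 / 14 = 18/7 ≈ 2.5714
min_mse = np.average((w_star * x_data - y_data)**2)      # ≈ 3.8095, the value cost converges to
print(w_star, min_mse)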
end
2. Gradient descent implemented with PyTorch autograd
import torch
from torch.autograd import Variable
x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0]]))
y_data = Variable(torch.Tensor([[2.0], [8.0], [6.0]]))
lr = 0.1
w = Variable(torch.FloatTensor([0]), requires_grad=True)
cost = []
for i in range(10):
    y_predict = x_data * w
    loss = torch.mean((y_predict - y_data)**2)   # mean squared error
    cost.append(loss.data.item())
    loss.backward()                              # autograd fills in w.grad
    w.data = w.data - lr * w.grad.data           # gradient descent update
    w.grad.data.zero_()                          # reset the gradient so it does not accumulate across iterations
print(w.data)
print(cost)
With the gradient reset after each update, w converges to 18/7 ≈ 2.5714 and the recorded loss decreases toward ≈ 3.8095, matching the NumPy implementation above.
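torch.autograd.Variable has been deprecated since PyTorch 0.4; a plain tensor created with requires_grad=True tracks gradients by itself. As a sketch, the same loop in that style (same data and learning rate, with the manual update wrapped in torch.no_grad()):

import torch
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [8.0], [6.0]])
lr = 0.1
w = torch.zeros(1, requires_grad=True)            # initial weight, tracked by autograd
cost = []
for i in range(10):
    y_predict = x_data * w
    loss = torch.mean((y_predict - y_data)**2)
    cost.append(loss.item())
    loss.backward()                               # fills in w.grad
    with torch.no_grad():                         # update the weight without recording it in the graph
        w -= lr * w.grad
    w.grad.zero_()                                # reset the gradient for the next iteration
print(w.item())
print(cost)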
end
3. Implementation with the PyTorch nn module
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(1, 1, bias=False)

    def forward(self, x):
        y_predict = self.linear(x)
        return y_predict
x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0]]))
y_data = Variable(torch.Tensor([[2.0], [8.0], [6.0]]))
model = Model()
criterion = torch.nn.MSELoss()                      # mean squared error, same loss as in the previous sections
opt = torch.optim.SGD(model.parameters(), lr=0.1)   # plain SGD, same learning rate as above
cost = []
for i in range(10):
    y_pre = model(x_data)
    loss = criterion(y_pre, y_data)
    cost.append(loss.data.item())
    opt.zero_grad()    # clear the gradients from the previous iteration
    loss.backward()    # backpropagate
    opt.step()         # let the optimizer update the parameters
print(list(model.parameters()))
print(cost)
[Parameter containing:
tensor([[2.5714]], requires_grad=True)]
[19.032915115356445, 3.877183675765991, 3.8098247051239014, 3.80952525138855, 3.809523344039917, 3.809523344039917, 3.809523820877075, 3.809523820877075, 3.809523820877075, 3.809523820877075]
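Once trained, the nn version can be used for inference directly. A minimal usage sketch (assuming the model above has just been trained; x_new is a made-up test input):

x_new = torch.Tensor([[4.0]])
with torch.no_grad():              # no gradients needed for inference
    print(model(x_new))            # about 4 * 2.5714 ≈ 10.29 with the converged weight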
end