import torch
from torch.autograd import Variable
import numpy as np
import time
# Set to False to force CPU execution even when a CUDA device is available.
use_cuda: bool = True
def get_data():  # Generate the data to be fitted
    """Build the toy training set for fitting ``y = w * x + b``.

    The inputs are the integers 1..20 and the targets equal the inputs, so the
    exact fit is ``w = 1, b = 0``.

    Returns:
        (X, y): ``X`` is a float32 tensor of shape ``(20, 1)`` and ``y`` is a
        float32 tensor of shape ``(20,)``. Both live on ``cuda:0`` when CUDA is
        available and the module-level ``use_cuda`` flag is set, otherwise on
        the CPU. Neither requires gradients.
    """
    train_x = np.arange(1, 21)  # [20] integers 1..20
    train_y = np.arange(1, 21)  # targets identical to the inputs
    device = 'cuda:0' if torch.cuda.is_available() and use_cuda else 'cpu'
    # Column vector so that X @ w (w has shape (1,)) gives one prediction per
    # row; -1 derives the row count from the data instead of hard-coding it.
    X = torch.as_tensor(train_x, dtype=torch.float32, device=device).view(-1, 1)
    y = torch.as_tensor(train_y, dtype=torch.float32, device=device)
    return X, y
def get_weights():  # Create the trainable parameters
    """Create a randomly initialised slope ``w`` and intercept ``b``.

    Returns:
        (w, b): two leaf tensors of shape ``(1,)`` drawn from a standard normal
        distribution, with ``requires_grad=True`` so that ``backward()``
        populates their ``.grad``. They are placed on ``cuda:0`` when CUDA is
        available and the module-level ``use_cuda`` flag is set, otherwise on
        the CPU.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() and use_cuda else 'cpu')
    # Plain tensors with requires_grad replace the deprecated Variable wrapper.
    w = torch.randn(1, device=device, requires_grad=True)
    b = torch.randn(1, device=device, requires_grad=True)
    return w, b
def simple_network(x, w, b):  # Linear model being fitted
    """Return the linear prediction ``x @ w + b``.

    With ``x`` of shape ``(N, 1)`` and ``w``/``b`` of shape ``(1,)`` (the
    shapes built by ``get_data``/``get_weights``), the result has shape
    ``(N,)``. The output stays connected to ``w`` and ``b`` in the autograd
    graph.
    """
    weighted = x.matmul(w)
    return weighted + b
def loss_fn(y, y_pred, w, b):  # Loss function
    """Compute the sum-of-squared-errors loss and back-propagate it.

    Side effects: any gradient already stored on ``w`` / ``b`` is reset to
    zero before ``loss.backward()`` runs. After the call, ``w.grad`` and
    ``b.grad`` therefore hold only the gradient of this loss, with no
    accumulation across calls.

    Args:
        y: target values.
        y_pred: predictions; must depend on ``w``/``b`` through autograd for
            their gradients to be populated.
        w, b: parameters whose gradients are reset before back-propagation.

    Returns:
        (loss, w, b): ``loss`` is the scalar loss detached from the graph;
        ``w`` and ``b`` are the same tensor objects that were passed in.
    """
    loss = (y_pred - y).pow(2).sum()
    for param in (w, b):
        if param.grad is not None:
            # Autograd accumulates into .grad; clear it so the backward() below
            # yields only this step's gradient.
            param.grad.zero_()
    loss.backward()
    return loss.detach(), w, b
def optimize(learning_rate, w, b):  # Update the weights by gradient descent
    """Apply one in-place gradient-descent step to ``w`` and ``b``.

    Computes ``param <- param - learning_rate * param.grad`` for each
    parameter.

    Args:
        learning_rate: step size multiplied into each gradient.
        w, b: tensors whose ``.grad`` has been populated (e.g. by
            ``loss_fn``).

    Returns:
        (w, b): the same tensor objects, updated in place.

    Raises:
        AttributeError: if ``w.grad`` or ``b.grad`` is ``None``.
    """
    with torch.no_grad():
        # Update in place so w and b stay the same leaf tensors held by the
        # caller; no_grad keeps the update out of the autograd graph (the
        # supported replacement for mutating through `.data`).
        w -= w.grad.mul(learning_rate)
        b -= b.grad.mul(learning_rate)
    return w, b
def main():
    """Fit ``y = w * x + b`` to the toy data by plain gradient descent.

    Performs 10000 gradient-descent steps with learning rate 1e-4. Each step:
    1. runs the forward pass;
    2. back-propagates the loss (``loss_fn``);
    3. applies the update (``optimize``);
    4. prints the loss measured before the update.

    Finally it prints the learned ``w`` and ``b`` and the elapsed wall time.
    """
    x, y = get_data()
    w, b = get_weights()
    start = time.perf_counter()  # monotonic clock suited to interval timing
    # Every iteration both back-propagates and steps. loss_fn zeroes the
    # gradients before each backward pass, so any gradient computed without
    # a following optimize() call would simply be thrown away.
    for _ in range(10000):
        y_pred = simple_network(x, w, b)
        loss, w, b = loss_fn(y, y_pred, w, b)
        w, b = optimize(0.0001, w, b)
        # .cpu() is a no-op for CPU tensors, so one code path serves both devices.
        print("loss:{}".format(loss.detach().cpu().numpy()))
    print("*" * 20)
    print('w: {:.4f}, b: {:.4f}'.format(w.item(), b.item()))
    print("speed time:{:.2f}".format(time.perf_counter() - start))
if __name__ == "__main__":
    # Train only when executed as a script, not when imported as a module.
    main()
PyTorch学习之线性拟合
最新推荐文章于 2023-09-17 21:05:52 发布