莫烦Pytorch系列教程3.1----关系拟合(回归 Regression)

import torch 
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F

x= torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
#x变为100 * 1的行向量

y= x.pow(2)+0.2*torch.rand(x.size())
#rand()是0到1的均匀分布
x,y = Variable(x),Variable(y)

#打印散点图
#plt.scatter(x.data.numpy(),y.data.numpy())
#


#定义神经网络逼近y=x*x
class Net(torch.nn.Module):#继承torch.nn.Module
    #每一个神经网络都会包含这两个功能
    def __init__(self,n_features,n_hiddens, n_outputs):#官方套路
        super(Net,self).__init__()#继承_init_()功能,官方套路
        #定义神经网络的每一层,比如说输入输出节点数
        self.hidden = torch.nn.Linear(n_features,n_hiddens)
        #隐藏层的输入输出节点
        self.predict = torch.nn.Linear(n_hiddens,n_outputs)
    
    def forward(self,x):#官方套路
        #前向传播,神经网络每一层的组合,
        #搭建神经网络
        #x是输入信息
        x=F.relu(self.hidden(x))
        x= self.predict(x)
        return x
    #本次的预测输出是不需要激活函数

#创建神经网络
net = Net(n_features=1,n_hiddens=10,n_outputs=1)
#一个节点输入,10个隐藏层,1个输出

#打印网络结构
print(net)


#实时打印过程开启
plt.ion()#Turn the interactive mode on
plt.show() #Display all figures.


optimizer = torch.optim.SGD(net.parameters(),lr = 0.5)
#SGD优化,并且传入net的所有参数,学习率是0.5
loss_func = torch.nn.MSELoss()
#预测值与真实值的误差计算公式,均方差

for t in range(100):
    prediction = net(x)
    #喂数据,输出计算值
    loss = loss_func(prediction,y) #计算两者误差
    optimizer.zero_grad()#清空上一步的残余更新值
    loss.backward()#开始反向传播
    optimizer.step()#将参数更新施加到net的parameters上
    
    #可视化训练过程
    if t % 5 ==0:
        #plot and show learning process
        plt.cla()
        plt.scatter(x.data.numpy(),y.data.numpy())
        plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.item(), fontdict={'size': 20, 'color':  'red'})
        plt.title("Regression Plot")
        plt.pause(0.1)
plt.ioff()#关闭实时交互过程
plt.show()#显示最后的图像
    
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值