pytorch搭建神经网络-简化版代码

import torch
import numpy as np

# 创建数据集: 每个样本有14个特征
x_train = np.array([
    [0.5, -1.2, 0.3, 0.8, 1.0, -0.5, 2.3, 1.2, -0.3, 1.5, -1.1, 0.6, -0.8, 0.7],
    [1.5,  2.2, 1.3, -0.7, 1.1,  0.5, -1.3, 0.4,  1.2, 0.8,  0.3, 0.6,  2.1, 0.2],
    [0.9, -0.2, -0.5, -1.2, 1.3, -1.1, 0.7,  1.5,  0.9, 1.0, -0.4, 0.5, -1.0, 1.4],
    [-0.4, 0.8, 1.2, -0.1, 1.5, 0.2, 0.6, -1.3, 1.0, 1.3, 0.3, -0.9, 1.1, 0.5],
    [1.0, 0.2, -1.4, 0.3, -0.7, 1.1, -0.1, 0.5, 0.6, 1.5, 0.7, -0.5, 0.9, -0.2]
], dtype=np.float32)

y_train = np.array([[5.0], [6.0], [4.0], [7.0], [3.0]], dtype=np.float32)



input_size = 14  #输入层神经元个数
hidden_size = 128 #隐藏层神经元个数
output_size = 1  #输出层神经元个数
batch_size = 16
my_nn = torch.nn.Sequential(
    torch.nn.Linear(input_size,hidden_size),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden_size,output_size)
)
cost = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(my_nn.parameters(),lr=0.001)


#训练网络
losses=[]
for i in range(1000):
    batch_loss=[]
    # 分批训练
    for start in range(0,x_train.shape[0],batch_size):
        end = start + batch_size if start+batch_size<x_train.shape[0] else len(x_train)  #开始的位置每次跳16个步长,结束的位置为开始位置+16,特判有没有超越界限
        xx = torch.tensor(x_train[start:end],dtype=torch.float32,requires_grad=True)
        yy = torch.tensor(y_train[start:end],dtype=torch.float32,requires_grad=True)

        # 前向传播不需要自己一个矩阵一个矩阵相乘了,可以直接输入到my_nn中
        prediction = my_nn(xx)
        loss = cost(prediction,yy)  # 损失函数
        loss.backward(retain_graph=True) # 反向求导
        optimizer.step() # 更新参数
        optimizer.zero_grad()
        batch_loss.append(loss.detach().numpy())
    if i %100==0:
        losses.append(np.mean(batch_loss))
        print(i,np.mean(batch_loss))


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

背水

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值