人工智能-作业3:例题程序复现 PyTorch版

人工智能-作业3:例题程序复现 PyTorch版

1.使用pytorch复现课上例题

运行代码:

import torch

x1, x2 = torch.Tensor([0.5]), torch.Tensor([0.3])
y1, y2 = torch.Tensor([0.23]), torch.Tensor([-0.07])
print("=====输入值:x1, x2;真实输出值:y1, y2=====")
print(x1, x2, y1, y2)
w1, w2, w3, w4, w5, w6, w7, w8 = torch.Tensor([0.2]), torch.Tensor([-0.4]), torch.Tensor([0.5]), torch.Tensor(
    [0.6]), torch.Tensor([0.1]), torch.Tensor([-0.5]), torch.Tensor([-0.3]), torch.Tensor([0.8])  # 权重初始值
w1.requires_grad = True
w2.requires_grad = True
w3.requires_grad = True
w4.requires_grad = True
w5.requires_grad = True
w6.requires_grad = True
w7.requires_grad = True
w8.requires_grad = True


def sigmoid(z):
    a = 1 / (1 + torch.exp(-z))
    return a


def forward_propagate(x1, x2):
    in_h1 = w1 * x1 + w3 * x2
    out_h1 = sigmoid(in_h1)  # out_h1 = torch.sigmoid(in_h1)
    in_h2 = w2 * x1 + w4 * x2
    out_h2 = sigmoid(in_h2)  # out_h2 = torch.sigmoid(in_h2)

    in_o1 = w5 * out_h1 + w7 * out_h2
    out_o1 = sigmoid(in_o1)  # out_o1 = torch.sigmoid(in_o1)
    in_o2 = w6 * out_h1 + w8 * out_h2
    out_o2 = sigmoid(in_o2)  # out_o2 = torch.sigmoid(in_o2)

    print("正向计算:o1 ,o2")
    print(out_o1.data, out_o2.data)

    return out_o1, out_o2


def loss_fuction(x1, x2, y1, y2):  # 损失函数
    y1_pred, y2_pred = forward_propagate(x1, x2)  # 前向传播
    loss = (1 / 2) * (y1_pred - y1) ** 2 + (1 / 2) * (y2_pred - y2) ** 2  # 考虑 : t.nn.MSELoss()
    print("损失函数(均方误差):", loss.item())
    return loss


def update_w(w1, w2, w3, w4, w5, w6, w7, w8):
    # 步长
    step = 1
    w1.data = w1.data - step * w1.grad.data
    w2.data = w2.data - step * w2.grad.data
    w3.data = w3.data - step * w3.grad.data
    w4.data = w4.data - step * w4.grad.data
    w5.data = w5.data - step * w5.grad.data
    w6.data = w6.data - step * w6.grad.data
    w7.data = w7.data - step * w7.grad.data
    w8.data = w8.data - step * w8.grad.data
    w1.grad.data.zero_()  # 注意:将w中所有梯度清零
    w2.grad.data.zero_()
    w3.grad.data.zero_()
    w4.grad.data.zero_()
    w5.grad.data.zero_()
    w6.grad.data.zero_()
    w7.grad.data.zero_()
    w8.grad.data.zero_()
    return w1, w2, w3, w4, w5, w6, w7, w8


if __name__ == "__main__":

    print("=====更新前的权值=====")
    print(w1.data, w2.data, w3.data, w4.data, w5.data, w6.data, w7.data, w8.data)

    for i
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值