【PyTorch深度学习实践】第6讲 逻辑回归

在这里插入图片描述
损失函数前面为什么加负号?
交叉熵是越大越好,损失值习惯上看的是越小越好,因此添加了负号。
交叉熵是用来判断实际的输出和期望的输出的接近程度。经常用在神经网络分类问题,也经常作为损失函数。

# 逻辑回归

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt


# 大于3.0为1
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0.0], [0.0], [1.0]])
# 分类


class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)

    # 通过sigmoid计算y_hat 预测值
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred


model = LogisticRegressionModel()

criterion = torch.nn.BCELoss(reduction='sum')
# 优化器 model.parameters()获取模型中需要优化的参数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

Epoch_list = []  # 保存epoch
Loss_list = []  # 保存每个epoch对应的loss

# 训练
for epoch in range(100):
    # 前馈
    y_pred = model(x_data)
    # 损失
    loss = criterion(y_pred, y_data)
    Epoch_list.append(epoch)
    Loss_list.append(loss)
    print(epoch, loss.item())

    # 梯度清零
    optimizer.zero_grad()
    # 反向传播
    loss.backward()
    # 更新优化参数 w和b
    optimizer.step()

    print('w=', model.linear.weight.item())
    print('b=', model.linear.bias.item())


# 绘图
#
# plt.plot(Epoch_list, Loss_list)
# plt.title("BCE")
# plt.ylabel('Loss')
# plt.xlabel('Epoch')
# plt.grid(ls='--')
# plt.show()
#

# 测试模型
x_test = torch.Tensor([[5.0]])
y_test = model(x_test)
print('y_pred=', y_test.data)


# start=0 end=10 样本数量200
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid(True)
plt.show()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值