一个简单的分类问题案例的pytorch实现

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

n_data = torch.ones(100, 2)

# torch.normal(mean, std, *, generator=None, out=None) → Tensor
# mean 是一个张量,每个输出元素的正态分布均值
# std 是一个张量,每个输出元素的正态分布的标准偏差
x0 = torch.normal(2*n_data, 1)
y0 = torch.zeros(100)  # 第一个分类标签为0
x1 = torch.normal(-2*n_data, 1)
y1 = torch.ones(100)   # 第二个分类标签为1
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)
y = torch.cat((y0, y1),).type(torch.LongTensor)

# matplotlib.pyplot.scatter(x, y, s=None, c=None, marker=None, cmap=None,
# norm=None, vmin=None, vmax=None, alpha=None, linewidths=None,
# verts=None, edgecolors=None, *, data=None, **kwargs)

# x, y: 表示的是大小为(n,)的数组,也就是我们绘制散点图的数据点,输入数据。
# s: 是一个实数或者是一个数组大小为(n,),可选,默认为20。点的面积。
# c: 表示的是颜色,可选。默认是蓝色'b',表示的是标记的颜色,
#    或者是一个表示颜色的字符,或者是一个长度为n的表示颜色的序列等等,
#    但是c不可以是一个单独的RGB数字,也不可以是一个RGBA的序列。可以是他们的二维数组(只有一行)。
# marker: 表示的标记的样式,可选,默认的是'o'
# cmap: Colormap,标量或者是一个colormap的名字,cmap仅仅当c是一个浮点数数组的时候才使用。
#       如果没有申明就是image.cmap,可选,默认为None。
# norm: Normalize,数据亮度在0-1之间,也是只有c是一个浮点数数组的时候才使用。
#       如果没有申明就是color.Normalize,就是默认None
# vmin,vmax: 标量,当norm存在的时候忽略。用来进行亮度数据的归一化,可选,默认None。
# alpha: 标量,0-1之间,可选,默认None,线的透明度
# linewidths: 也就是标记点的长度,默认None,即点的直径

# plt.scatter(x.data[:,0], x.data[:,1], c=y.data, s=100, lw=0, cmap='RdYlGn')
# plt.show()

class Net(torch.nn.Module):
    def __init__(self, n_data, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_data, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = torch.relu(self.hidden(x))  #使用ReLU作为激活函数
        x = self.predict(x)
        return x

net = Net(2, 10, 2)  # 类的实例化
print(net)

plt.ion()
plt.show()

optimizer = torch.optim.SGD(net.parameters(), lr=0.03)  # 使用随机梯度下降作为优化器
loss_function = torch.nn.CrossEntropyLoss()  # 使用交叉熵误差来作为损失函数

for t in range(100):
    out = net(x)  # 正向传播

    loss = loss_function(out, y)  # 计算损失值

    optimizer.zero_grad()  # 优化器清零
    loss.backward()  # 反向传播求梯度
    optimizer.step()  # 更新可训练参数
    if t % 2 == 0:
        plt.cla()  # Clear axis即清除当前图形中的当前活动轴。其他轴不受影响
        prediction = torch.max(F.softmax(out, dim=0), 1)[1]  # 输出softmax计算之后的最大值的索引
        pred_y = prediction.data.squeeze()
        target_y = y.data
        plt.scatter(x.data[:, 0], x.data[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')  # 绘制散点图
        accuracy = sum(pred_y==target_y) / 200  # 计算识别精度
        plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})  # 输出一个实时的精度计算文本
        plt.pause(0.05)  # 0.05秒的暂停时间

plt.ioff()
plt.show()

这段代码运行之后在plt.pause(0.05)处会报错,原因还未搞懂,希望各位大佬不吝赐教,谢谢!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值