pytorch系列(二):pytorch分类算法简单实现

#该案例主要是做简单节点分类任务的模型训练
import torch
import matplotlib.pyplot as plt

#数据集构造  都是tensor张良结构
n_data=torch.ones(100,2)#数据基本形态  100x2的维度
x0=torch.normal(2*n_data,1)#第一种类型的数据,服从为1的正态分布 100x2的维度 横坐标和纵坐标
y0=torch.zeros(100)#第一类数据的标签  0  100x1的维度

x1=torch.normal(-2*n_data,1)#第二种类型的数据,服从为1的正态分布
y1=torch.ones(100)#标签为1

#合并训练集和标签集
x=torch.cat((x0,x1),0).type(torch.FloatTensor)#0表示按列合并   合并后列数不变
y=torch.cat((y0,y1),).type(torch.LongTensor)#表示按行合并  合并后行数不变    注意type的类型

#作图查看数据
plt.scatter(x.data.numpy()[:,0],x.data.numpy()[:,1],lw=5,cmap='RdYlGn')
plt.show()

#建立神经网络
import torch.nn.functional as fun

#声明Net类
class Net(torch.nn.Module):
    #初始化函数
    def __init__(self,n_feature,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden=torch.nn.Linear(n_feature,n_hidden)
        self.output=torch.nn.Linear(n_hidden,n_output)
        
    #前向传播函数
    def forward(self,x):
        x=fun.relu(self.hidden(x))
        x=self.output(x)#跟线性回归不一样的是,这个并不是最终预测结果,还要进行softmax操作
        return x

#实例化网络
net=Net(n_feature=2,n_hidden=10,n_output=2)#几个类别 output就是几
print(net)

#训练网络

#声明优化器
optimizer=torch.optim.SGD(net.parameters(),lr=0.2)#传入net的所有参数,指定学习率
#注意计算误差的时候,真实值并不是one-hot形式的,而是1D的Tensor(softmax得到最大概率索引)
#预测值是2D的tensor (batch,n_classes)

#声明损失函数
loss_fun=torch.nn.CrossEntropyLoss()#用交叉熵计算损失(分类问题经常如此)

#可视化训练过程
plt.ion()#打开plt交互模式
plt.show()

#迭代100次训练
for i in range(2):
    out=net(x)#喂数据集
    
    loss=loss_fun(out,y)#计算交叉熵误差值
    optimizer.zero_grad()#清空上一步残余的参数更新量
    loss.backward()#反向传播,更新参数
    optimizer.step()#参数更新值施加到net的parameters上面
    
    if i%2==0:
        plt.cla()#清除面板信息
        prediction=torch.max(fun.softmax(out),1)[1]#第一个1表示对一个100样本softmax后按行选取最大值,[1]表示取出最大值对应的类别索引
        pre_y=prediction.data.numpy().squeeze()#将张量形式变为普通numpy形式
        
        target_y=y.data.numpy()
        
        #c=pre_y表示根据分类填充颜色,cmap是只有当c是浮点数数组时使用
        plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pre_y, s=100, lw=0, cmap='RdYlGn')
        accuracy = sum(pre_y == target_y)/200.  # 预测中有多少和真实值一样
        plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color':  'red'})
        plt.pause(0.01)
plt.ioff()  # 停止画图
plt.show()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值