pytorch(二)

学习莫烦pytorch视频,部分代码进行注释

#classification.py

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt

import math
import pdb

n_data = torch.ones(100,2)

#pdb.set_trace()
x0 = torch.normal(2*n_data, 1)#生成均值为2,方差为1的tensor
#print(x0)
y0 = torch.zeros(100)
x1 = torch.normal(-2*n_data, 1)
y1 = torch.ones(100)
x=torch.cat((x0,x1),0).type(torch.FloatTensor)#32bit floating#0:竖着拼 1:横着拼
y=torch.cat((y0,y1),0).type(torch.LongTensor)#64bit integer#只有一维
#print(y0)
x, y =Variable(x), Variable(y)
#print(y)
#plt.scatter(x.data.numpy()[:,0], x.data.numpy()[:,1], c=y.data.numpy(), s=100, lw=0,cmap='RdYlGn')#c:颜色,s:控制点大小,lw:加不加一样,,cmap控制颜色的数组

#plt.show()

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden=torch.nn.Linear(n_feature,n_hidden)
        self.output=torch.nn.Linear(n_hidden,n_output)

    def forward(self, x):
        x=F.relu(self.hidden(x))
        x=self.output(x)
        return x

def myLoss(pre, y):
    s=0
    #pdb.set_trace()
    for i in range(200):
        p1=pre[i][0]
        p2=pre[i][1]
        #p11=math.exp(p1)/(math.exp(p1)+math.exp(p2))
        #p22=1-p11
        tmp = -pre[i][y[i]] +math.log(math.exp(p1)+math.exp(p2)) 
        s=s+tmp
    return s/200

net = Net(2,10,2)
'''
net = torch.nn.Sequential(
    torch.nn.Linear(2,10),
    torch.nn.ReLU(),
    torch.nn.Linear(10,2),
    )
    不用class,直接定义,用法一样
'''
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func=torch.nn.CrossEntropyLoss()

plt.ion()
for i in range(20):
    pre = net(x)

    loss = loss_func(pre, y)#这里的loss不是直接用的,具体实现参见myLoss
    #myloss = myLoss(pre.data.numpy(),y.data.numpy())
    #print(pre)
    #print(loss)
    #print(myloss)
    #break
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i%2==0:
        plt.cla()
        _,prediction=torch.max(F.softmax(pre),1)
        pred_y=prediction.data.numpy().squeeze()
        target_y=y.data.numpy()
        plt.scatter(x.data.numpy()[:,0], x.data.numpy()[:,1], c=pred_y,s=100,lw=0,cmap='RdYlGn')
        ac=sum(pred_y==target_y)/200
        plt.text(1.5, -4, 'Accuracy=%.2f'%ac, fontdict={'size':20, 'color':'red'})
        plt.pause(0.5)
plt.ioff()
plt.show()





  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值