pytorch5-SOFTMAX回归手动实现

SOFTMAX回归手动实现

import torch
import torchvision
import numpy as np
import sys
import d2l
batch_size=256
train_iter,test_iter =d2l.load_data_fashion_mnist(batch_size)

num_inputs=784
num_output=10

W=torch.tensor(np.random.normal(0,0.01,(num_inputs,num_outputs)),dtype=torch.float)
b=torch.zeros(num_outputs,dtype=torch.float)

W.required_grad_(requires_grad=True)
b.required_grad_(required_grad=True)

def softmax(X):
	X_exp=X.exp()
	partition=X_exp.sum(dim=1,keepdim=True)
	return X_exp/partition # 这里使用了广播的方法

'''
# 测试
X=torch.rand((2,5))
X_prob=softmax(x)
print(X_prob,X_prob.sum(dim=1))
'''

定义模型

def net(X):
	return softmax(torch.mm(X.view((-1,num_inputs)),W)+b)

定义损失函数

y_hat=torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])
y=torch.LongTensor([0,2])
print(".gather()函数的解释")
print(y_hat.gather(1,y.view(-1,1)))

'''
tensor([[0.1000],
 [0.5000]])'''
#下⾯实现了3.4节(softmax回归)中介绍的交叉熵损失函数。
def cross_entropy(y_hat, y):
    return -torch.log(y_hat.gather(1, y.view(-1, 1)))
 
'''
给定⼀个类别的预测概率分布 y_hat ,我们把预测概率最⼤的类别作为输出类别。如果它与真实类
别 y ⼀致,说明这次预测是正确的。分类准确率即正确预测数量与总预测数量之⽐'''
def accuracy(y_hat,y):
    return (y_hat.argmax(dim=1)==y).float().mean().item()

'''
训练softmax回归的实现跟“线性回归的从零开始实现” ⼀节介绍的线性回归中的实现⾮常相似。我们同
样使⽤⼩批量随机梯度下降来优化模型的损失函数。在训练模型时,迭代周期数 num_epochs 和学习
率 lr 都是可以调的超参数。改变它们的值可能会得到分类更准确的模型'''

num_epochs,lr=5,0.1

d2l.train_ch3(net,train_iter,test_iter,cross_entropy,num_epochs,batch_size,[W,b],lr)

#3.6.8 预测
#给定⼀系列图像(第三⾏图像输出),我们⽐较⼀下它们的真实标签(第⼀⾏⽂本输出)和模型预测结果(第⼆⾏⽂本输出)
x,y=iter(test_iter).next()
true_labels=d2l.get_fashion_mnist_labels(y.numpy())
pred_lables=d2l.get_fashion_mnist_labels(net(x).argmax(dim=1).numpy())
titles=[true+'\n'+pred for true,pred in zip(true_labels,pred_lables)]
d2l.show_fashion_mnist(x[0:9],titles[0:9])
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值