目录
transforms.Normalize(mean, std)
补充知识点
torch.nn.LogSoftmax()
torch.nn.LogSoftmax()和我们的softmax差不多只不过就是最后加入了一个log,关于softmax的详情大家可以看这一篇博客深度学习之感知机,激活函数,梯度消失,BP神经网络
我们只需要知道里面的一些常用参数就好,一般就是这个dim
下面是softmax的logsoftmax和他一样。
torchvision.transforms
这个是一个功能强大的数据处理集合库
这里主要说一下transforms.Compose还有transforms.ToTensor以及transforms.Normalize(mean, std)
transforms.Compose
这个功能函数可以看作一个功能函数容器,它里面可以放多个功能函数(多个的话应该把这些功能函数放在一个列表内),当该功能函数定义好后,其内部的其他功能函数也会随之按我们给定的要求定义好,当我们调用compose的实例的时候,就会按照我们在容器内部摆放的顺序,从左至右的依次调用功能函数
transforms.ToTensor
用于对载入的图片数据进行类型转换,将之前PIL图片的数据(应该是np.array()类型的)转换成Tensor数据类型的张量,以便于我们后续的数据使用
(补充:PIL默认输出的图片格式为 RGB)
(下面图片内容来自网络,侵权必删。)
transforms.Normalize(mean, std)
数据归一化处理。
下面是我在学习过程中遇到的问题:
1.归一化就是要把图片3个通道中的数据整理到[-1, 1]或者[0,1]区间,x = (x - mean(x))/std(x)只要输入数据集x就可以直接算出来,为什么Normalize()函数的mean和std(标准差)还需要我们手动输入数值呢?
我的理解是,我们一开始就算好可以极大的减少运算量,如果我们自动的让他算的话,我们这个每一个图片都要算,这样运算量就极大。
2.RGB单个通道的值是[0, 255],所以一个通道的均值应该在127附近才对。我们接下来的代码如下图所示
所填的是0.5,0.5,这是为什么?
因为我们应用了torchvision.transforms.ToTensor,他会将数据归一化到[0,1](是将数据除以255),transforms.ToTensor( )会把HWC会变成C *H *W(拓展:格式为(h,w,c),像素顺序为RGB),所以我们就应该输入0.5,0.5
(该图片内容来自网络,侵权璧必删)
我们一般来说只需了解一些常用的库函数,我们这里只是提一下我们这次会用到的函数,其余的函数若想了解的话,推荐大家看这位作者写的博客PyTorch之torchvision.tra()nsforms详解[原理+代码实现]-CSDN博客
torchvision.datasets
这里面包含了一些我们目前常用的一些数据集
我们这里主要讲一下mnist
MNIST(手写数字数据集)
MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。
在我们pytorch里面的torchvision里面是有的,我们可以直接用。
torchvision.datasets.MNIST(root, train=True, transform=None, target_transform=None, download=False)
相关参数:
root:就是我们从网上下载的数据集所放的目录,也就是文件路径
train:train=True意思就是下载的是训练集,如果train=Flase,那就是下载的是测试集
transform: 因为我们的数据直接下载过后一般都是要进行处理,所以这个后面跟的是一个transforms数据处理函数的一个实例化对象
download:为True就是从网络上下载数据集保存在我们的root路径中,为False就是不下载。
其余参数可自行查阅
torch.utils.data.DataLoader
(图片内容来自网络,侵权必删)
(其实我们记住常用的,dataset,batch_size,shuffle这三个常用的参数就好了)
torch.nn.NLLLoss()
torch.nn.CrossEntropyLoss()
首先我们要了解交叉熵损失函数,torch.nn.CrossEntropyLoss()。
什么是熵?
熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度。交叉熵越小,表示数据越接近真实样本。
(预测的概率就是我们的预测值的准确值)
torch.nn.NLLLoss()
torch.nn.NLLLoss输入是一个对数概率向量和一个目标标签,它与torch.nn.CrossEntropyLoss的关系可以描述为:
假设有张量x,先softmax(x)得到y,然后再log(y)得到z,然后我们已知标签b,则:
NLLLoss(z,b)=CrossEntropyLoss(x,b)
代码:
nllloss = nn.NLLLoss()
predict = torch.Tensor([[2, 3, 1],
[3, 7, 9]])
predict = torch.log(torch.softmax(predict, dim=-1))
label = torch.tensor([1, 2])
nllloss(predict, label)
运行结果:tensor(0.2684)
而我们用 torch.nn.CrossEntropyLoss
cross_loss = nn.CrossEntropyLoss()
predict = torch.Tensor([[2, 3, 1],
[3, 7, 9]])
label = torch.tensor([1, 2])
cross_loss(predict, label)
运行结果: tensor(0.2684)
enumerate() 函数
这是python的一个内置函数。
enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
Python 2.3. 以上版本可用,2.6 添加 start 参数。
enumerate(sequence, [start=0])
sequence -- 一个序列、迭代器或其他支持迭代对象。
start -- 下标起始位置的值。
例如:
>>> seasons = ['Spring', 'Summer', 'Fall', 'Winter']
>>> list(enumerate(seasons))
[(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]
>>> list(enumerate(seasons, start=1)) # 下标从 1 开始
[(1, 'Spring'), (2, 'Summer'), (3, 'Fall'), (4, 'Winter')]
next() 函数
python的内置函数
next() 返回迭代器的下一个项目。
next() 函数要和生成迭代器的 iter() 函数一起使用。
返回值:返回下一个项目。
next(iterable[, default])
iterable -- 可迭代对象
default -- 可选,用于设置在没有下一个元素时返回该默认值,如果不设置,又没有下一个元素则会触发 StopIteration 异常。
例如:
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# 首先获得Iterator对象:
it = iter([1, 2, 3, 4, 5])
# 循环:
while True:
try:
# 获得下一个值:
x = next(it)
print(x)
except StopIteration:
# 遇到StopIteration就退出循环
break
结果:
1
2
3
4
5
pytorch中.detach()
在 PyTorch 中,detach()
方法用于返回一个新的 Tensor,这个 Tensor 和原来的 Tensor 共享相同的内存空间,但是不会被计算图所追踪,也就是说它不会参与反向传播,不会影响到原有的计算图,这使得它成为处理中间结果的一种有效方式,通常在以下两种情况下使用:
第一:在计算图中间,需要截断反向传播的梯度计算时。例如,当计算某个 Tensor 的梯度时,我们希望在此处截断反向传播,而不是将梯度一直传递到计算图的顶部,从而减少计算量和内存占用。此时可以使用 detach()
方法将 Tensor 分离出来。
第二:在将 Tensor 从 GPU 上拷贝到 CPU 上时,由于 Tensor 默认是在 GPU 上存储的,所以直接进行拷贝可能会导致内存不一致的问题。此时可以使用 detach()
方法先将 Tensor 分离出来,然后再将分离出来的 Tensor 拷贝到 CPU 上。
torch.squeeze()
详情请看这篇博客深度学习之张量的处理(代码笔记)
torch.optim.SGD
torch.optim.SGD 是 PyTorch 中用于实现随机梯度下降(Stochastic Gradient Descent,SGD)优化算法的类。SGD 是一种常用的优化算法。
原理部份可以看我的这篇博客:机器学习优化算法(深度学习)-CSDN博客
我们主要介绍一下常用的参数:
torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)
(图片内容来自网络,侵权必删)
独热编码
就是如果是0-9这十个数,那我们就用[1,0,0,0,0,0,0,0,0,0]表示0,[0,1,0,0,0,0,0,0,0,0]表示1,等等
其余同理
代码实现
关于bp神经网络的原理,我们可以看这篇博客:深度学习之感知机,激活函数,梯度消失,BP神经网络-CSDN博客
我们接下来就只是讲解代码实现。
bp网络的搭建
#搭建bp神经网络
class BPNetwork(torch.nn.Module):
def __init__(self):
super(BPNetwork,self).__init__()
#我们的每张图片都是28*28也就是784个像素点
#第一个隐藏层
self.linear1=torch.nn.Linear(784,128)
#激活函数,这里选择Relu
self.relu1=torch.nn.ReLU()
#第二个隐藏层
self.linear2=torch.nn.Linear(128,64)
#激活函数
self.relu2=torch.nn.ReLU()
#第三个隐藏层:
self.linear3=torch.nn.Linear(64,32)
# 激活函数
self.relu3 = torch.nn.ReLU()
#输出层
self.linear4=torch.nn.Linear(32,10)
# 激活函数
self.softmax=torch.nn.LogSoftmax()
#前向传播
def forward(self,x):
#修改每一个批次的样本集尺寸,修改为64*784,因为我们的图片是28*28
x=x.reshape(x.shape[0],-1)
#前向传播
x=self.linear1(x)#784*128
x=self.relu1(x)
x=self.linear2(x)#128*64
x=self.relu2(x)
x=self.linear3(x)#64*32
x=self.relu3(x)
x=self.linear4(x)#输出层32*10
x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
#上面的这些都可以这几使用x=self.model(x)来代替,为什么能用它,我的理解是,我们继承的class moudle 然后对立面写好的模型框架进行定义,而这个方法就是可以直接调用我们定义好的神经网络
return x
(一些关键点都在注释上)
搭建这次的BP神经网络我的隐藏层有三层,分别是128,64,32个神经元,因为我们的图片是28*28=784得,我们需要把其展开成一维,所以第一层网络是784*128得,这样输入层中每一行代表一个样本(或者说一张图片得所有像素点),因为我们的每个神经元都有一个参数,第一层网络中每一列都是一个神经元对应的参数,所以最后就是n*784和784*128两个矩阵相乘,最后得到n*128得矩阵,以此类推,最后因为我们的输出要用到独热编码的思想,所以我们的输出层调整为10个神经元,或者说最后得线性网络是32*10。
建立我们的神经网络对象
#建立我们的神经网络对象
model=BPNetwork()
#定义损失函数
critimizer=torch.nn.NLLLoss()
#定义优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.003,momentum=0.9)
epochs=15#循环得轮数
#每轮抽取次数的和
a=0
loss_=[]
a_=[]
font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
for i in range(epochs):
# 损失值参数
sumloss = 0
for imges,labels in trainload:
a+=1
#前向传播
output=model(imges)
#反向传播
loss=critimizer(output,labels)
loss.backward()
#参数更新
optimizer.step()
#梯度清零
optimizer.zero_grad()
#损失累加
sumloss+=loss.item()
loss_.append(sumloss)
a_.append(a)
print(f"第{i+1}轮的损失:{sumloss},抽取次数和:{a}")
plt.figure()
plt.plot(a_,loss_)
plt.title('损失值随着抽取总次数得变化情况:',fontproperties=font, fontsize=18)
plt.show()
注意看注释。
SGD还有损失函数等等上面的补充内容里面都有说明,这里不在多加阐述。
预测
#开始预测
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
batch_index,(imagess,labelss)=next(example)
# bath_index=0
# imagess=0
# labelss=0
# for i,j in example:
# bath_index=i
# (imagess, labelss)=j
fig=plt.figure()
for i in range(64):
pre=model(imagess[i])#预测
#第一张图片对应的pre得格式:
# print(pre)
# tensor([[-2.7053e+01, -1.1105e-03, -1.2767e+01, -1.1126e+01, -1.6005e+01,
# -2.0953e+01, -2.3342e+01, -6.8246e+00, -1.2127e+01, -1.6131e+01]],
# grad_fn= < LogSoftmaxBackward0 >)
#接下来我们要用到独热编码的思想,我们取最大的数,也就是最高的概率对应得下标,就相当于这个最高概率对应得独热编码里面的1,其他是0
pro = list(pre.detach().numpy()[0])
pre_label=pro.index(max(pro))
#print(pre_label)
(注意看注释)
我们实际上也可以不用next()我们可以直接用for循环比如这样。
import d2l.torch as d2l
import math
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
traindata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=True,download=True,transform=transform)#训练集60,000张用于训练
testdata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=False,download=True,transform=transform)#测试集10,000张用于测试
#利用DataLoader加载数据集
trainload=DataLoader(dataset=traindata,shuffle=True,batch_size=64)
testLoad=DataLoader(dataset=testdata,shuffle=False,batch_size=64)
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
a=0
images=0
labelss=0
for i,j in example:
a+=1
index = i
(imagess, labelss) = j
print(imagess[0])
print('数据集中抽取的64份数据的纯数据部份的尺寸:',imagess.shape)
print(imagess[0].shape)
print(labelss[0])
print(labelss[0].shape)
if a==1:
break
这样得到的结果我们可以看到:
D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\prictice.py
tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.3412,
0.4510, 0.2471, 0.1843, -0.5294, -0.7176, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, 0.7412,
0.9922, 0.9922, 0.9922, 0.9922, 0.8902, 0.5529, 0.5529,
0.5529, 0.5529, 0.5529, 0.5529, 0.5529, 0.5529, 0.3333,
-0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.4745,
-0.1059, -0.4353, -0.1059, 0.2784, 0.7804, 0.9922, 0.7647,
0.9922, 0.9922, 0.9922, 0.9608, 0.7961, 0.9922, 0.9922,
0.0980, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -0.8667, -0.4824, -0.8902,
-0.4745, -0.4745, -0.4745, -0.5373, -0.8353, 0.8510, 0.9922,
-0.1686, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -0.3490, 0.9843, 0.6392,
-0.8588, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -0.8275, 0.8275, 1.0000, -0.3490,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, 0.0118, 0.9922, 0.8667, -0.6549,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -0.5373, 0.9529, 0.9922, -0.5137, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, 0.0431, 0.9922, 0.4667, -0.9608, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -0.9294, 0.6078, 0.9451, -0.5451, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -0.0118, 0.9922, 0.4275, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-0.4118, 0.9686, 0.8824, -0.5529, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.8510,
0.7333, 0.9922, 0.3020, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.9765, 0.5922,
0.9922, 0.7176, -0.7255, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -0.7020, 0.9922,
0.9922, -0.3961, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -0.7569, 0.7569, 0.9922,
-0.0980, -0.9922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, 0.0431, 0.9922, 0.9922,
-0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -0.5216, 0.8980, 0.9922, 0.9922,
-0.5922, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -0.0510, 0.9922, 0.9922, 0.7176,
-0.6863, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -0.0510, 0.9922, 0.6235, -0.8588,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000]]])
数据集中抽取的64份数据的纯数据部份的尺寸: torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
tensor(7)
torch.Size([])
进程已结束,退出代码0
我们可以看到一张图片得数据格式(这里面是已经归一化处理过的),因为我们的手写字体识别是单通道的灰度图,所以size是[1,28,28],这里很正确,对于彩色图三通道得来说,会有些不一样。
完整代码
import matplotlib.pyplot as plt
from matplotlib import font_manager
print('BP识别MNIST任务说明---------------------')
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
#导入数据集并且进行数据处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
traindata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=True,download=True,transform=transform)#训练集60,000张用于训练
testdata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=False,download=True,transform=transform)#测试集10,000张用于测试
#利用DataLoader加载数据集
trainload=DataLoader(dataset=traindata,shuffle=True,batch_size=64)
testLoad=DataLoader(dataset=testdata,shuffle=False,batch_size=64)
#搭建bp神经网络
class BPNetwork(torch.nn.Module):
def __init__(self):
super(BPNetwork,self).__init__()
#我们的每张图片都是28*28也就是784个像素点
#第一个隐藏层
self.linear1=torch.nn.Linear(784,128)
#激活函数,这里选择Relu
self.relu1=torch.nn.ReLU()
#第二个隐藏层
self.linear2=torch.nn.Linear(128,64)
#激活函数
self.relu2=torch.nn.ReLU()
#第三个隐藏层:
self.linear3=torch.nn.Linear(64,32)
# 激活函数
self.relu3 = torch.nn.ReLU()
#输出层
self.linear4=torch.nn.Linear(32,10)
# 激活函数
self.softmax=torch.nn.LogSoftmax()
#前向传播
def forward(self,x):
#修改每一个批次的样本集尺寸,修改为64*784,因为我们的图片是28*28
x=x.reshape(x.shape[0],-1)
#前向传播
x=self.linear1(x)#784*128
x=self.relu1(x)
x=self.linear2(x)#128*64
x=self.relu2(x)
x=self.linear3(x)#64*32
x=self.relu3(x)
x=self.linear4(x)#输出层32*10
x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
#上面的这些都可以这几使用x=self.model(x)来代替,为什么能用它,我的理解是,我们继承的class moudle 然后对立面写好的模型框架进行定义,而这个方法就是可以直接调用我们定义好的神经网络
return x
#建立我们的神经网络对象
model=BPNetwork()
#定义损失函数
critimizer=torch.nn.NLLLoss()
#定义优化器
optimizer=torch.optim.SGD(model.parameters(),lr=0.003,momentum=0.9)
epochs=15
#每轮抽取次数的和
a=0
loss_=[]
a_=[]
font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
for i in range(epochs):
# 损失值参数
sumloss = 0
for imges,labels in trainload:
a+=1
#前向传播
output=model(imges)
#反向传播
loss=critimizer(output,labels)
loss.backward()
#参数更新
optimizer.step()
#梯度清零
optimizer.zero_grad()
#损失累加
sumloss+=loss.item()
loss_.append(sumloss)
a_.append(a)
print(f"第{i+1}轮的损失:{sumloss},抽取次数和:{a}")
plt.figure()
plt.plot(a_,loss_)
plt.title('损失值随着抽取总次数得变化情况:',fontproperties=font, fontsize=18)
plt.show()
#开始预测
example=enumerate(testLoad)#从测试集里面随机抽64份并且记录下来里面的内容和下标
batch_index,(imagess,labelss)=next(example)
# bath_index=0
# imagess=0
# labelss=0
# for i,j in example:
# bath_index=i
# (imagess, labelss)=j
fig=plt.figure()
for i in range(64):
pre=model(imagess[i])#预测
#第一张图片对应的pre得格式:
# print(pre)
# tensor([[-2.7053e+01, -1.1105e-03, -1.2767e+01, -1.1126e+01, -1.6005e+01,
# -2.0953e+01, -2.3342e+01, -6.8246e+00, -1.2127e+01, -1.6131e+01]],
# grad_fn= < LogSoftmaxBackward0 >)
#接下来我们要用到独热编码的思想,我们取最大的数,也就是最高的概率对应得下标,就相当于这个最高概率对应得独热编码里面的1,其他是0
pro = list(pre.detach().numpy()[0])
pre_label=pro.index(max(pro))
#print(pre_label)
#图像显示
img=torch.squeeze(imagess[i]).numpy()
plt.subplot(8,8,i+1)
plt.tight_layout()
plt.imshow(img,cmap='gray',interpolation='none')
plt.title(f"预测值:{pre_label}",fontproperties=font, fontsize=9)
plt.xticks([])
plt.yticks([])
plt.show()
结果及其图像显示
D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\学习过程\第四周的代码\代码一:BP识别MNIST任务说明.py
BP识别MNIST任务说明---------------------
D:\learn_pytorch\学习过程\第四周的代码\代码一:BP识别MNIST任务说明.py:49: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
x=self.softmax(x)#最后输出的数值我们需要利用到独热编码的思想
第1轮的损失:815.2986399680376,抽取次数和:938
第2轮的损失:281.06414164602757,抽取次数和:1876
第3轮的损失:205.76270231604576,抽取次数和:2814
第4轮的损失:159.9431014917791,抽取次数和:3752
第5轮的损失:131.47989665158093,抽取次数和:4690
第6轮的损失:109.93954652175307,抽取次数和:5628
第7轮的损失:95.26143277343363,抽取次数和:6566
第8轮的损失:85.17149402853101,抽取次数和:7504
第9轮的损失:75.1239058477804,抽取次数和:8442
第10轮的损失:68.23363681556657,抽取次数和:9380
第11轮的损失:60.05981844640337,抽取次数和:10318
第12轮的损失:54.82598690944724,抽取次数和:11256
第13轮的损失:51.70861432119273,抽取次数和:12194
第14轮的损失:46.613128249999136,抽取次数和:13132
第15轮的损失:43.05269447225146,抽取次数和:14070
进程已结束,退出代码0
可以看到,经过15轮得训练后,我们基本上已经完全能识别出来了。