人工智能学习07--pytorch12--AlexNet+花数据集+pytorch

文章详细介绍了AlexNet网络结构,包括卷积层、MaxPooling和全连接层的使用。通过PyTorch搭建网络时,提到了padding的处理、ReLU激活函数以及权重初始化。在训练花分类数据集时,文章讨论了数据划分、批量大小和后端设置问题,并解决了训练过程中遇到的错误。
摘要由CSDN通过智能技术生成

AlexNet详解

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
这个网络可以看成上下两部分,因为作者使用了两块GPU进行并行运算。
为了便于理解,看其中一部分就行,因为这两部分是一样的。

  1. Conv1
    在这里插入图片描述
    在这里插入图片描述
    (1+2)代表的是padding的大小,虽然图中没有给出,但是自己可以推理出来,也可以从源码里面发现。
    1:在特征矩阵左边加上一列0,在右边加上两列0,上面加上一列0,下面加上两列0 。

(括号弹幕大佬:
2p就等于两边padding的像素之和,不一定严格要求两边padding的像素个数一样


其实这里双gpu和单gpu卷积层和下采样一致,这是因为batch只是增加厚度,而做卷积只会影响高宽。只是经过全链接的线性层,需要batch数据交互,后面才会互相影响。

  • MaxPool1
    Maxpooling下采样
    在这里插入图片描述
    池化操作只会改变输出矩阵的高度和宽度,不会改变特征矩阵的深度

  • Conv2
    在这里插入图片描述

  • MaxPool2
    在这里插入图片描述

  • Conv3
    在这里插入图片描述

  • Conv4
    在这里插入图片描述

  • Conv5
    在这里插入图片描述

  • Maxpool3
    在这里插入图片描述

  • 3个全连接层
    将Maxpooling得到的输出展平,与这三个层进行连接。
    最后一个层有1000个节点,因为论文用到的数据集有1000个类别。

花分类数据集

  • 下载数据集
    五个类别在这里插入图片描述
    下载地址https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    在这里插入图片描述
    在这里插入图片描述
  • 训练集、测试集
    使用脚本将数据按照9:1的比例划分为训练集和验证集
    在这里插入图片描述
    回到文件夹检查,发现划分成功
    在这里插入图片描述

使用pytorch搭建并训练花分类网络

1、定义组件

在这里插入图片描述
padding[1,2]是左边1右边2
在这里插入图片描述
Sequential:之前每使用一个层结构,都要self.模块名称=……,如果每个模块都这么定义,那么工作量太大。
对网络层次比较多的网络,可以使用Sequential函数来精简代码

  • Conv2d
    关于padding:
    ①整数:如1,则会在矩阵上下左右补一行/列0;
    ②tuple:
    在这里插入图片描述
    ③:nn.ZeroPad2d 精确补0
    在这里插入图片描述
    在卷积过程中,如果按照公式计算得到的不是整数:
    在这里插入图片描述
  • ReLU
    inplace:pytorch中增加计算量,但能减小内存容量的方法
  • Maxpool1
    在这里插入图片描述
  • 代码的卷积核个数都弄成论文中的一半了
    在这里插入图片描述
    下一步就是跟全连接层连接
  • 三个全连接层
    drop使全连接层的节点按照一定比例失活
    在这里插入图片描述
    init_weights为True则进入初始化权重函数
    self.module
    在这里插入图片描述
    返回一个迭代器。这个迭代器会遍历网络中所有的模块–>则这个迭代器会遍历网络中每一个层结构:
    在这里插入图片描述
    此处初始化权重的方法:
    在这里插入图片描述
    pytorch里面自动使用上面这个初始化方法

2、定义正向传播

在这里插入图片描述

train

在这里插入图片描述
在这里插入图片描述
好,成功了
在这里插入图片描述
太感动了。
在这里插入图片描述
jason这样:
在这里插入图片描述
查看数据集:
把batch_size设置成4,因为是查看四张图片。把后面的shuffle随机打乱改成True,否则一直按顺序读取那就是同一文件夹中的同一种类。
在这里插入图片描述
但是报错了。
在这里插入图片描述
这样改之后好了,但是有新的错误
在这里插入图片描述
MatplotlibDeprecationWarning: Support for FigureCanvases without a required_interactive_framework attribute was deprecated in Matplotlib 3.6 and will be removed two minor releases later.

找了许多方法,最后发现是Matplotlib的后端问题。一开始输出后端:
在这里插入图片描述
按照网页说法:
https://blog.csdn.net/m0_37724919/article/details/128874187(不是我这个问题,但是先码住,说不定以后换到服务器会用到)

https://blog.csdn.net/Reginasong/article/details/128841357(解决问题)
在这里插入图片描述
打印出来的是这个:
module://backend_interagg
感觉不对,按照网上的,把后端改成了:
在这里插入图片描述
解决问题。
在这里插入图片描述
注释掉这段代码之后把之前的地方改回去,继续下一步。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

validate

在这里插入图片描述

main

在这里插入图片描述
在这里插入图片描述

predict

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

遇到了这个问题:
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

使用自己的数据集训练网络

修改网络参数
train:
在这里插入图片描述
predict:
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值