Day9 神经网络-卷积层conv2d的使用


前言

去官方文档查看torch.nn中卷积层的相关介绍,包括参数、使用方法以及卷积核的原理
在这里插入图片描述


一、CONV2D

1.参数介绍

在这里插入图片描述
in_channels:输入的通道数;最初的输入一般都是3(RGB),或者1(灰度图);之后每一层输入通道数都应为前一层的输出通道数

out_channels:输出通道数

kernel_size:巻积核大小(上图中所展示的就是3*3,填写参数时仅需写一个3即可)

stride: 步长,卷积核在图像上每次移动的距离,上图为1

padding_mode:在图像周围补充0;默认padding_mode=0;就是不填充0

dilation:正常情况下为1;为2则代表空洞卷积

bias:偏置;卷积可以想象成是做了一次乘法,偏执就是作一次加法

2.相关公式

(1)卷积过程
在这里插入图片描述
在这里插入图片描述

(2)卷积前后由于参数的设置导致尺寸变化:
在这里插入图片描述

N:batch_size

C:通道数

H,W:图像的高和宽

self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

按照上面的公式输出如下:

H=(3+20-1(3-1)-1)/1 +1=2

W=(3+20-1(3-1)-1)/1 +1=2

二、卷积实例操作过程

1.数据集准备和加载

依旧以CIFAR10数据集为例,下载dataset后作为参数传递给dataloader加载,注意batch_size选取64为一组,其他参数默认

dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64)

2.开始搭建简单神经网络

  • 继承父类Module为主框架搭建,初始化函数别忘了引用父类的初始化
  • 设置卷积层conv1使用Conv2d方法,具体参数已经介绍,在这里选取输入数据通道为3,输出通道为6,卷积过程由forward前向传播调用
  • 注意卷积层定义语句,多练习
class Www(nn.Module):
   def __init__(self):
       super(Www, self).__init__()
       # 定义卷积层
       self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

   def forward(self, x):
       x = self.conv1(x)
       return x

3.实例化Www网络并输出查看

www = Www()
print(www)

输出结果中可以看到卷积层(conv1)的各项参数,简单卷积层搭建完毕。
在这里插入图片描述

4.把dataloader中加载好的数据都放入卷积层中处理,并查看卷积前后格式变化

之前在数据的读取章节介绍过如何将dataloader加载好的数据读取

for data in dataloader:
   imgs, targets = data
   output = www(imgs)
   # 观察卷积前后数据的格式变化
   print(imgs.shape)
   print(output.shape)

输出结果:
在这里插入图片描述
观察到batch_size都是64(dataloader时设置的batch_siaze)没有变化,输入是3channel,输出是6channel,通道数多了,但是尺寸变小了,这就是本次卷积的作用和处理结果

5.结合Tensorboard(含报错解决办法)

1.代码实现

# 结合tensorboard,观察卷积前后图片变化
writer = SummaryWriter("../logs")
step = 0
# 把dataloader中加载好的数据都放入卷积层中处理,并查看卷积前后格式变化
for data in dataloader:
   imgs, targets = data
   output = www(imgs)
   # 观察卷积前后数据的格式变化
   print(imgs.shape)
   print(output.shape)

   writer.add_images("input", imgs, step)
   writer.add_images("output", output, step)

   step = step + 1

2.运行报错及解决方案

在这里插入图片描述

根据路径找到报错位置,原因是在经过卷积层处理后,输出output中的图片通道由3个通道变成了6个通道,但彩色图片(RGB)通道数为3,在写入tensorboard时无法显示,发生报错。

解决方法:torch.Size([64,6,30,30]) -> [xxx,3,30,30]

output = torch.reshape(output, (-1, 3, 30, 30))
writer.add_images("output", output, step)

三、输出结果展示

1.input展示结果

在这里插入图片描述

2.output展示结果

在这里插入图片描述
在这里插入图片描述

四、完整代码

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 数据集准备和加载
dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64)

# 开始搭建简单神经网络
class Www(nn.Module):
    def __init__(self):
        super(Www, self).__init__()
        # 定义卷积层
        self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        return x

# 初始化Www网络
www = Www()
# print(www)

# 结合tensorboard,观察卷积前后图片变化
writer = SummaryWriter("logs")
step = 0
# 把dataloader中加载好的数据都放入卷积层中处理,并查看卷积前后格式变化
for data in dataloader:
    imgs, targets = data
    output = www(imgs)
    # 观察卷积前后数据的格式变化
    # print(imgs.shape)
    # print(output.shape)

    writer.add_images("input", imgs, step)

    # 报错解决办法:torch.Size([64,6,30,30]) -> [xxx,3,30,30]
    output = torch.reshape(output, (-1, 3, 30, 30))
    writer.add_images("output", output, step)

    step = step + 1
    
writer.close()

五、总结心得

本次学习之后使我认识到阅读官方文档的重要性
回顾了dataset和dataloader数据读取、tensorboard的使用、神经网络主框架Module
学习了卷积层的conv2d的相关概念和机制,对参数的设置有了清晰的理解,以及如何将dataloader中的数据读取并送入神经网络经过卷积层处理再展示到tensorboard中

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值