1.导入需要的库和模块:
import torch
from torch import nn
from torch.nn import Conv2d,MaxPool2d,Flatten,Linear,Sequential
from torch.utils.tensorboard import SummaryWriter
2.定义神经网络的类:
class Sjnet(nn.Module):
def __init__(self):
super(Sjnet, self).__init__()
self.mode=Sequential(
Conv2d(3,32,5,padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024,64),
Linear(64,10)
)
def forward(self,x):
x=self.mode(x)
return x
这段代码定义了一个名为"Sjnet"的神经网络类。
在类的构造函数__init__
中,首先调用super(Sjnet, self).__init__()
来继承nn.Module
的属性和方法。
然后,定义了一个名为mode
的Sequential
模型对象。Sequential
是一个容器,按照顺序组合各个层和操作。
-
mode
中的层和操作按照顺序如下:- 第一层是
Conv2d(3, 32, 5, padding=2)
,表示输入通道数为3,输出通道数为32,卷积核大小为5x5,填充为2。 - 接下来是
MaxPool2d(2)
,表示2x2的最大池化操作。 - 然后是
Conv2d(32, 32, 5, padding=2)
,表示输入通道数为32,输出通道数为32,卷积核大小为5x5,填充为2。 - 再次进行2x2的最大池化操作。
- 紧接着是
Conv2d(32, 64, 5, padding=2)
,表示输入通道数为32,输出通道数为64,卷积核大小为5x5,填充为2。 - 再次进行2x2的最大池化操作。
- 然后是
Flatten()
,将输入的多维数据展平为一维向量。 - 接下来是
Linear(1024, 64)
,表示输入大小为1024,输出大小为64的全连接层。 - 最后是
Linear(64, 10)
,表示输入大小为64,输出大小为10的全连接层。
- 第一层是
-
forward
方法定义了数据在网络中的前向传播过程。输入x
首先通过mode
模型进行处理,然后返回输出。
神经网络"Sjnet"的结构是通过一系列的卷积层、池化层和全连接层组成的。它接受输入数据,经过一系列的卷积和池化操作提取特征,然后通过全连接层进行分类或回归等任务。
3.创建网络实例:
sjnet=Sjnet()
通过Sjnet()
创建了一个"Sjnet"神经网络的实例,赋值给变量sjnet
。这行代码将会调用Sjnet
类的构造函数,初始化网络的结构。
使用print(sjnet)
打印了网络的信息。这将输出网络的结构和参数信息,包括每一层的名称、类型和参数数量等。
4.进行前向传播:
sjnet=Sjnet()
print(sjnet)
input=torch.ones((64,3,32,32))
output=sjnet(input)
print(output.shape)
首先,创建了一个大小为(64, 3, 32, 32)
的输入张量input
,其中64是批次大小,3是输入通道数,32x32是输入图像的高度和宽度。
然后,将输入张量input
传递给网络实例sjnet
,进行前向传播。这行代码会调用网络的forward
方法,将输入张量传递给网络模型进行处理。
最后,使用print(output.shape)
打印输出张量output
的形状。这将输出一个元组,表示输出张量的大小。具体输出的形状取决于网络的结构和输入的大小,可以根据输出的形状了解网络的输出维度信息。
其输出结果:
torch.Size([64, 10])
5.TensorBoard中可视化网络的结构:
writer=SummaryWriter('seq_logs')
writer.add_graph(sjnet,input)
writer.close()
使用TensorBoard的SummaryWriter来记录神经网络的图结构,并将其保存到名为'seq_logs'的日志目录中。创建一个SummaryWriter对象,并传递了日志目录的路径作为参数。这里将日志目录命名为'seq_logs',用于存储神经网络的图结构和其他相关信息。使用SummaryWriter对象的add_graph方法,将神经网络模型sjnet
和输入张量input
传递给它。这个方法会将神经网络的图结构以及输入张量的形状信息写入到日志中。