以CIFAR10的moedl为例,其结构图为:
首先编写一个网络结构:
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
class Cifa10(nn.Module):
def __init__(self):
super(Cifa10, self).__init__()
self.model = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(64*4*4, 64), # 图中缺少这一步
Linear(64, 10)) # 10个类别
def forward(self, input):
output = self.model(input)
return output
cifa10 = Cifa10()
input = torch.ones([64, 3, 32, 32])
output = cifa10(input)
print(output.shape)
运行后其输出结果为:
torch.Size([64, 10])
接下来使用tensorboard对其计算过程进行可视化:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("./logs")
writer.add_graph(cifa10, input)
writer.close()
运行后会在当前目录创建一个logs目录,有tensorboard相应的缓存文件。
然后在当前目录下打开terminal(要先进入当前的虚拟环境:conda activate 环境名),输入命令:tensorboard --logdir=logs
,会出现一个网址,ctrl点击进入:
双击其结构进行展开:
展开到可以看到其计算过程,以最后一层的线性层为例: