pytorch tensorboard_Pytorch的模型结构可视化(tensorboard)

在pytorch中,可以导入tensorboard模块,可视化网络结构及训练流程。

下面通过“CNN训练MNIST手写数字分类”的小例子来学习一些可视化工具的用法,只需要加少量代码。

一、tensorboardX的安装

pip 

二、导入tensorboardX

import 

三、搭建模型

#定义超参数
batch_size = 64
learning_rate = 1e-2
num_epoches = 20

#对数据进行预处理
data_tf = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize([0.5],[0.5])]
)


# 定义网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(True))

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(True))

        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 128),
            nn.ReLU(True),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

model = CNN()
print(model)

#下载数据集MNIST手写数字训练集
train_dataset = datasets.MNIST(
    root = './data',train=True,transform = data_tf,download = True)
test_dataset = datasets.MNIST(
    root = './data',train = False,transform = data_tf)
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

四、保存模型结构

在这里保存模型的数据流和结构:w.add_graph()

model = CNN()
dummy_input = torch.rand(20, 1, 28, 28)  # 假设输入20张1*28*28的图片
with SummaryWriter(comment='LeNet') as w:
    w.add_graph(model, (dummy_input,))

五、运行代码及可视化

1.运行代码

2.在Pycharm命令行输入

tensorboard --logdir = C:Usershuangxin1PycharmProjectsuntitledruns

注意 tensorboard --logdir= 路径,这里的路径改为runs文件下面生成的文件的完整路径,即:

08d5acc86d52044173c47496beef93dd.png

在浏览器打开命令行生成的地址,可以看到模型图结构:

3278624606a928e83b58146e1f0218b9.png
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值