# 回顾一下使用 TensorBoard 的过程
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn
# Event files are written under ./logs (relative to the working directory);
# point TensorBoard at the same directory: `tensorboard --logdir=logs`.
writer = SummaryWriter(log_dir="logs")
class MLP(nn.Module):
    """Fully connected classifier: flatten -> (Linear -> ReLU) * k -> Linear.

    With the default arguments this is the original 784-512-128-10 network
    (two ReLU hidden layers) for 28x28 inputs flattened to 784 values.

    Args:
        in_features: Number of values each sample is flattened to.
        hidden_sizes: Widths of the hidden layers, in order.
        num_classes: Width of the output layer (one logit per class).
    """

    def __init__(self, in_features=28 * 28, hidden_sizes=(512, 128), num_classes=10):
        super().__init__()
        self.in_features = in_features

        layers = []
        width = in_features
        for hidden in hidden_sizes:
            layers += [nn.Linear(width, hidden), nn.ReLU()]
            width = hidden
        layers.append(nn.Linear(width, num_classes))

        # Attribute name kept as ``Net`` so state_dict keys ("Net.0.weight", ...)
        # match checkpoints saved from the original class.
        self.Net = nn.Sequential(*layers)

    def forward(self, input):
        """Flatten ``input`` to shape (-1, in_features) and return the logits.

        ``reshape`` is used instead of ``view`` because ``view`` raises on
        non-contiguous tensors (e.g. after ``transpose``/``permute``), while
        ``reshape`` only copies when it has to.
        """
        flat = input.reshape(-1, self.in_features)
        return self.Net(flat)
# Build the model and a random dummy batch of shape (32, 1, 28, 28); add_graph
# runs it through the model to record the network structure, so the batch only
# needs the right shape and dtype.
model = MLP()
# Named dummy_input (not ``input``) so the builtin ``input()`` is not shadowed.
dummy_input = torch.rand(32, 1, 28, 28, dtype=torch.float32)
try:
    writer.add_graph(model, dummy_input)
finally:
    # Flush and close the event file even if graph tracing raises.
    writer.close()
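# 以上简单写了一个神经网络模型 MLP，并用 writer.add_graph 把网络结构写入 logs 目录。
# 然后在命令行中运行：
#     tensorboard --logdir=logs
# 在浏览器中打开 TensorBoard（默认地址 http://localhost:6006）的 GRAPHS 页面，即可查看网络结构。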