1、安装
在pytorch的环境下,安装tensorboard和tensorflow
conda install tensorboard
conda install tensorflow
我安装在虚拟环境d2l中
验证是否安装成功
from torch.utils.tensorboard import SummaryWriter
输入如上代码,如果没有报错,则代表安装成功
2、简单实例
在pytorch官网找到实例
创建日志
from torch.utils.tensorboard import SummaryWriter
#创建1个名为“logs”的日志
writer = SummaryWriter("logs")
x = range(100)
for i in x:
#写入y=2*x的曲线
writer.add_scalar('y=2x', i * 2, i)
#关闭
writer.close()
运行结束后,已经创建了1个日志文件
消费日志
打开终端控制台
在控制台,检查是否路径正确。正确的话,输入以下,并回车。logs是日志名
tensorboard --logdir=logs
点击链接,查看即可
注意:不能同时查看多个日志,容易出错
3、画模型函数--以resnet为例
from torch.utils.tensorboard import SummaryWriter
import d2l.torch
import torch
from torch.nn import functional as F
from torch import nn
class Residual(nn.Module): #@save
def __init__(self,input_channels,num_channels,use_1x1conv=False,strides=1):
super().__init__() self.conv1=nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=strides)
self.conv2=nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1)
if use_1x1conv:
self.conv3=nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=strides)
else:
self.conv3=None
self.bn1=nn.BatchNorm2d(num_channels)
self.bn2=nn.BatchNorm2d(num_channels)
def forward(self,x):
y=F.relu(self.bn1(self.conv1(x)))
y=self.bn2(self.conv2(y))
if self.conv3:
x=self.conv3(x)
y+=x
return F.relu(y)
b1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=64,kernel_size=7,padding=3,stride=2),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3,padding=1,stride=2))#conv2d尺寸减半,池化也
def resnet_block(input_channels,num_channels,num_residuls,first_block=False):
block = []
for i in range(num_residuls):
if i==0 and not first_block:
block.append(
Residual(input_channels,num_channels,use_1x1conv=True,strides=2))
else:
block.append(
Residual(num_channels,num_channels))
return block
b2 = nn.Sequential(*resnet_block(64,64,2,True))
b3 = nn.Sequential(*resnet_block(64,128,2,False))
b4 = nn.Sequential(*resnet_block(128,256,2,False))
b5 = nn.Sequential(*resnet_block(256,512,2,False))
resnet = nn.Sequential(b1,b2,b3,b4,b5,
nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten(),
nn.Linear(in_features=512,out_features=10))
#以上是resnet网络的定义
writer = SummaryWriter("model_logs")
#创建1个名为“model_logs”日志
X = torch.randn(size=(1,1,224,224))
writer.add_graph(model=resnet,input_to_model=X)
#将输入X代入模型,写进日志
writer.close()
#关闭日志
成功创建名为“model_logs”的日志
打开终端,点击本地链接查看网络
会显示下图所示的大致框架
双击加号之后,会展现出具体的网络细节 ,并且显示了每阶段向量形状。每一步骤都可以继续深化展现