代码:
##train.py##
.....
from torch.utils.tensorboard import SummaryWriter
.......
#在main方法里创建和关闭writer对象
#将writer传递给Solver类
示例的train方法定义在solver.py的Solver训练类里,如果你是直接使用train方法,不必传递
def main(args):
#数据准备
# 构建数据字典
data =
model= #创建模型
writer = SummaryWriter(log_dir=args.save_folder) # 指定日志保存目录
..............
solver = Solver(data, model, optimizer, scheduler, args, writer)
solver.train()
writer.close() # 关闭tensorboard
##solver.py##
class Solver(object):
def __init__(self, data, model, optimizer, scheduler, args, writer):
self.args = args
..........
self.writer = writer #接收 writer 对象
# Training config
# save and load model
# logging
def train(self):
train_loss = []
val_loss = []
.........
#训练验证
#计算损失
#模型图的代码
self.writer.add_graph(self.model, torch.randn(batchnum, channels, length).cuda() if self.use_cuda else torch.randn(batchnum, channels, length))
环境:
包管理器下:
pip install tensorboard
pip install torch-tb-profiler
注意:python解释器要是3.8以上,不然会出现:
from typing import TYPE_CHECKING, Generic, Iterator, NamedTuple, TypeVar, TypedDict, overload ImportError: cannot import name 'TypedDict' from 'typing'
低版本 Python 环境中,typing
模块中没有TypedDict
。这个问题通常发生在 Python 版本较旧的时候。
当你的项目必须使用3.7以下的python版本时,可以在该环境下训练好你的网络,得到文件
然后切换python3.8的环境启动tensorboard也是可以的。
vscode下:
代码准备好后,vscode会通知安装或者启动tensorboard
使用:
在代码中直接点击启动:
vscode下会出现:
出现这个后:可以再在浏览器里输入:http://localhost:6006/
仍然是一样的结果,并且还可以下载图片(vscode扩展下好像点击下载没反应:
更新:经过试验,前面的方法还是太麻烦了:
##train.py##
.....
from torch.utils.tensorboard import SummaryWriter
.......
#在main方法里创建和关闭writer对象
def main(args):
#数据准备
# 构建数据字典
data =
model= #创建模型
writer = SummaryWriter(log_dir=args.save_folder) # 指定日志保存目录
#模型图的代码
writer.add_graph(model, torch.randn(batchnum, channels, length).cuda() if args.use_cuda else torch.randn(batchnum, channels, length))
..............
与训练类无关,只要建立好模型就可以直接用了,需要注意的是,输入数据和模型权重要在一个设备上(cpu或gpu)
writer.close() # 关闭tensorboard