几个常见的模型记录插件介绍
Tensorboard
记录数据
记录简单数据的方法
1. add_scalar():记录单个标量(一个 tag 对应一条曲线)
2. add_scalars():在同一个主 tag 下记录多个标量(多条曲线显示在同一张图中)
在 Google Colaboratory(Colab)中运行以下代码:
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import numpy as np
np.random.seed(47)  # fix the NumPy random seed

# ----------------------------------- 0 SummaryWriter -----------------------------------
# Demo switch: set flag to 0 to skip this example.
flag = 1
if flag:
    max_epoch = 100
    writer = SummaryWriter(comment='test_comment', filename_suffix="test_suffix")

    for x in range(max_epoch):
        # Two independent curves, one tag each.
        writer.add_scalar('y=2x', x * 2, x)
        writer.add_scalar('y=pow_2_x', 2 ** x, x)

        # Two curves grouped under a single main tag.
        grouped_values = {
            "xsinx": x * np.sin(x),
            "xcosx": x * np.cos(x),
        }
        writer.add_scalars('data/scalar_group', grouped_values, x)

    # Close the writer once all steps have been logged.
    writer.close()
接着在新的单元格中运行以下命令,加载 TensorBoard 扩展并读取 ./runs 目录下的日志:
# Load the TensorBoard notebook extension
%load_ext tensorboard
%tensorboard --logdir='./runs'
就能看到以下画面
进行模型监控
# Training loop with TensorBoard monitoring: for each epoch, iterate over
# train_loader, run forward/backward/optimizer steps, and log scalars and
# histograms through `writer`.
# NOTE(review): indentation was lost in this listing; the nesting described in
# the comments below is inferred from the comments and the step arguments
# (iter_count vs. epoch) — confirm against the original script.
# NOTE(review): MAX_EPOCH, net, train_loader, valid_loader, optimizer,
# criterion, scheduler, writer, iter_count, train_curve, log_interval and
# val_interval are defined outside this excerpt.
for epoch in range(MAX_EPOCH):
loss_mean = 0.
correct = 0.
total = 0.
net.train()
for i, data in enumerate(train_loader):
iter_count += 1
# forward
inputs, labels = data
outputs = net(inputs)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
# update weights
optimizer.step()
# Accumulate classification statistics (predicted class = argmax over dim 1).
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
# NOTE(review): .numpy() only works on CPU tensors; if this is ever run on
# GPU, .item() (or .cpu() first) would be needed — confirm the device.
correct += (predicted == labels).squeeze().sum().numpy()
# Print training information: loss_mean is averaged over log_interval
# iterations and then reset.
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0.
# Record data; it is saved to the event file. The logged loss is the current
# batch's loss.item() (not loss_mean), with iter_count as the step.
writer.add_scalars("Loss", {
"Train": loss.item()}, iter_count)
writer.add_scalars("Accuracy", {
"Train": correct / total}, iter_count)
# Every epoch, record gradients and weights as histograms (step = epoch).
# NOTE(review): param.grad is presumably non-None for every parameter here;
# frozen/unused parameters would have grad None — verify.
for name, param in net.named_parameters():
writer.add_histogram(name + '_grad', param.grad, epoch)
writer.add_histogram(name + '_data', param, epoch)
scheduler.step() # update the learning rate
# validate the model
if (epoch+1) % val_interval == 0:
correct_val = 0.
total_val = 0.
loss_val = 0.
net.eval()
with torch.no_grad():
for j, data in enumerate(valid_loader):
inputs, labels = data
outputs = net(inputs)
loss = criterion(outputs, labels)
_, predicted = torch.max(outputs.data, 1