Sacred is a Python library that helps researchers configure, organize, log, and reproduce experiments.
Official documentation
Official GitHub
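1. Brief introduction
The figure above is taken from this blog post.
Judging from that author's experiments, Sacred is a Python tool that can be used with any framework. It is very strong at parameter management, but its front-end display is comparatively weak.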
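2. Example usage
As the figure above shows, Sacred is very simple to use; its many other experiment-logging features are described in the official documentation. Installation only requires the following command (or download the files manually from GitHub):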
pip install sacred
The following is a fairly complete training script:
from sacred import Experiment
from sacred.observers import FileStorageObserver, MongoObserver
from sacred.utils import apply_backspaces_and_linefeeds
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

ex = Experiment("mnist_cnn")
# Option 1: record runs in a MongoDB database (requires a running MongoDB instance).
# ex.observers.append(MongoObserver(url='localhost:27017', db_name='sacred'))
# Option 2 (used here): no database; record runs in local files under my_exp/.
ex.observers.append(FileStorageObserver('my_exp'))
# Clean up captured output that relies on backspaces / carriage returns (e.g. tqdm progress bars)
ex.captured_out_filter = apply_backspaces_and_linefeeds

# Hyperparameter settings
@ex.config
def myconfig():
    # Device configuration (kept as a plain string so it is easy to serialize and to override)
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # Hyperparameters
    num_epochs = 5
    num_classes = 10
    batch_size = 100
    learning_rate = 0.001
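
# Convolutional neural network (replace with your own network)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        # Assumption: the original listing kept only the final layer. These two conv blocks
        # are an illustrative stand-in chosen to match the 7 * 7 * 32 input of self.fc
        # (28x28 -> 14x14 -> 7x7, 32 channels).
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=5, padding=2), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = out.reshape(out.size(0), -1)  # flatten to (batch, 7 * 7 * 32)
        return self.fc(out)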

# Sacred injects the values from myconfig() into main() by argument name;
# automain also runs the experiment when this file is executed as a script.
@ex.automain
def main(_run, device, num_epochs, num_classes, batch_size, learning_rate):
    # MNIST dataset
    train_dataset = torchvision.datasets.MNIST(root='/home/ubuntu/Datasets/MINIST/',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)
    test_dataset = torchvision.datasets.MNIST(root='/home/ubuntu/Datasets/MINIST/',
                                              train=False,
                                              transform=transforms.ToTensor())

    # Data loaders
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    model = ConvNet(num_classes).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # _run.log_scalar records a metric: log a plain float rather than a tensor,
            # and use the global iteration index so every point gets its own step.
            _run.log_scalar('training.loss', loss.item(), epoch * total_step + i)

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

    # Test the model
    model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Record the test result as a metric
    _run.log_scalar('test.correct', 100 * correct / total)
    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

    # Save the model checkpoint
    torch.save(model.state_dict(), 'model.ckpt')
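3. Viewing the results
With a database and a front-end visualization tool configured, such as MongoDB + Omniboard (or others), the experience is better. The following two blog posts cover installation and usage:
Blog post 1
Blog post 2
The figure above is from blog post 2.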
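
Without a database, the files written by `FileStorageObserver` can also be inspected directly. The sketch below assumes the observer's usual layout, where each run gets a numbered directory such as `my_exp/1/` containing a `metrics.json` file that maps every metric name to parallel `steps`, `values`, and `timestamps` lists. The helper names `load_metric` and `summarize_run` are made up for this illustration and are not part of Sacred.

```python
import json
from pathlib import Path


def load_metric(run_dir, name):
    """Return (steps, values) for one metric stored in <run_dir>/metrics.json."""
    with open(Path(run_dir) / "metrics.json", encoding="utf-8") as f:
        metrics = json.load(f)
    entry = metrics[name]
    return entry["steps"], entry["values"]


def summarize_run(run_dir):
    """Map each metric name to the last value that was logged for it."""
    with open(Path(run_dir) / "metrics.json", encoding="utf-8") as f:
        metrics = json.load(f)
    return {
        name: entry["values"][-1]
        for name, entry in metrics.items()
        if entry["values"]  # skip metrics that have no recorded values
    }
```

For the script in section 2, a call such as `summarize_run("my_exp/1")` would then give the last `training.loss` and the `test.correct` value, assuming the first run was stored under `my_exp/1`.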