pytorch&tf等 深度学习实验管理工具(Sacred)

Sacred 是一个 Python 库, 可以帮助研究人员配置、组织、记录和复制实验。
官方文档
official github

1. 简单介绍

在这里插入图片描述
上图来自该博客
从大佬的实验来看,scared是一个可以在任意框架下使用的python工具。其在参数管理方面非常优秀,但是前端显示比较弱势。

2. 案例使用

在这里插入图片描述
如上图,sacred的使用非常简单,还有很多记录实验的特性见官方文档。安装只需(或者手动去github下载文件):

pip install sacred

以下代码是一个较完整的训练代码

from sacred import Experiment
from sacred.observers import MongoObserver
from sacred.utils import apply_backspaces_and_linefeeds
from sacred.observers import FileStorageObserver

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


ex = Experiment("mnist_cnn")
#ex.observers.append(MongoObserver.create(url='localhost:27017', db_name='sacred'))
# 这里使用了数据库(可以不采用,采用本地文件记录FileStorageObserver),如下: 
ex.observers.append(FileStorageObserver('my_exp'))
ex.captured_out_filter = apply_backspaces_and_linefeeds # 过滤非标准输出(tqdm)


# 超参数设置
@ex.config
def myconfig():
    # Device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Hyper parameters
    num_epochs = 5
    num_classes = 10
    batch_size = 100
    learning_rate = 0.001


# Convolutional neural network (你的网络)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
       
        out = self.fc(out)
        return out


# 自动解析myconfig()参数,直接导入并运行
@ex.automain
def main(_run,device,num_epochs,num_classes,batch_size,learning_rate):
    # MNIST dataset
    train_dataset = torchvision.datasets.MNIST(root='/home/ubuntu/Datasets/MINIST/',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)

    test_dataset = torchvision.datasets.MNIST(root='/home/ubuntu/Datasets/MINIST/',
                                              train=False,
                                              transform=transforms.ToTensor())

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    model = ConvNet(num_classes).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ## _run.log_scalar 加入metric信息
			_run.log_scalar('training.loss', loss, epoch)
    		
            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

    # Test the model
    model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # 加入测试信息
		_run.log_scalar('test.correct', (100 * correct / total))
        print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

    # Save the model checkpoint
    torch.save(model.state_dict(), 'model.ckpt')

3. 结果展示

在这里插入图片描述
如果配置了数据库和前端可视化工具 MongoDB + Omniboard(others),有更好的体验。以下是两篇安装和介绍博客:
博客1
博客2
在这里插入图片描述
上图来自博客2。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值