# Register a training task via `ae` (defined elsewhere); the return value is
# kept in `msg`.
# NOTE(review): 'training_name' and '/path/log' look like placeholders -- confirm.
msg = ae.create_training_task('training_name', '/path/log')

# Run epochs 1..args.epochs (inclusive): train, then evaluate and log metrics.
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    # `test` returns a (loss, accuracy) pair for the held-out loader.
    loss, accuracy = test(model, device, test_loader)
    print('test_loss=', loss)
    print('test_accuracy=', accuracy)
    # Record per-epoch metrics using the epoch number as the step value.
    # (`tb` is presumably a TensorBoard SummaryWriter -- confirm.)
    tb.add_scalar('Loss', loss, epoch)
    tb.add_scalar('Accuracy', accuracy, epoch)
    # Step the scheduler once per epoch, after evaluation
    # (presumably a learning-rate scheduler attached to `optimizer`).
    scheduler.step()

# Persist the model once, after all epochs have finished.
# NOTE(review): this pickles the whole model object; PyTorch recommends saving
# `model.state_dict()` for portability. The ".pht" extension is also unusual
# (".pt"/".pth" is conventional). Both are kept unchanged because code that
# loads this file may depend on them.
if args.save_model:
    torch.save(model, "./mnist_cnn.pht")