使用ray对pytorch模型进行超参数调节

def main(num_samples=10,max_num_epochs=10,gpus_per_trial=1):
    data_dir=os.path.abspath('./data')
    load_data(data_dir)
    config={
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16])
    }
    scheduler=ASHAScheduler(metric='loss',mode='min',max_t=max_num_epochs,
                           grace_period=1,reduction_factor=2)
    #定义显示的一些指标
    reporter=CLIReporter(
    parameter_columns=['l1','l2','lr','batch_size'],
    metric_columns=['loss','accuracy','training_iteration'])
    #使用偏函数partial,partial(函数,给函数的参数)
#     result=tune.run(partial(train_cifar,data_dir=data_dir),
#                     resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},config=config,
#                    num_samples=num_samples,scheduler=scheduler,progress_reporter=reporter)
    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter)
    best_trial=result.get_best_trial('loss','min','last')
    print('最好的配置{}'.format(best_trial.config))
    print('最好的验证损失:{}'.format(best_trial.last_result['loss']))
    print('最好的最后验证精度:{}'.format(best_trial.last_result['accuracy']))
    
    best_trained_model=Net(best_trial.config['l1'],best_trial.config['l2'])
    device='cuda:0' if torch.cuda.is_available() else 'cpu'
    best_trained_model.to(device)
    best_checkpoint_dir=best_trial.checkpoint.value
    model_state,optimizer_state=torch.load(os.path.join(best_checkpoint_dir,'checkpoint'))
    best_trained_model.load_state_dict(model_state)
    test_acc=test_accuracy(best_trained_model,device)
    print('最好的测试集精度:{}'.format(test_acc))

pytorch官网超参调节教程Hyperparameter tuning with Ray Tune — PyTorch Tutorials 1.12.0+cu102 documentation

 看了后有些疑惑,去ray的官网查看

Key Concepts — Ray 1.13.0rr

 ray在调参时如果报错Trials did not complete,那就是前面模型的定定义和使用除了问题,导致无法运行

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值