python实现类似wandb.sweep搜参功能

该脚本使用argparse处理命令行参数,进行随机搜索以优化模型的dropout、emb_size、learning_rate等超参数。它使用wandb工具记录实验结果,并将最佳性能指标写入CSV文件。训练过程涉及SAINT模型,并针对特定数据集进行。
摘要由CSDN通过智能技术生成
import argparse
from argparse import Namespace
from wandb_train import main
from pprint import pprint
import wandb
import random
import os
import csv


if __name__ == "__main__":
    config_dataset_name=['assist2009','assist2012','assist2017','nips_task34']
    config_dropout=[0.1, 0.2, 0.3,0.4]
    config_emb_size=[64,128,256]
    config_num_attn_heads=[1,2,4,8]
    config_n_blocks=[1,2,4,8]
    config_learning_rate=[0.1,0.01,0.001,0.003,0.005]

    file_path = './result/'+config_dataset_name[3]+'_saint_2.csv'
    if os.path.exists(file_path):
        print("文件存在")
    else:
        file_name='./result/'+config_dataset_name[3]+'_saint_2.csv'
        data = ['best_auc','best_acc','best_epoch', 'dataset_name', 'dropout','emb_size','learning_rate','num_attn_heads','n_blocks']
        with open(file_name, 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(data)

    for i in range(900):
        random.seed(i)
        drop_num=random.randint(0,len(config_dropout)-1)
        emb_num=random.randint(0,len(config_emb_size)-1)
        heads_num=random.randint(0,len(config_num_attn_heads)-1)
        blocks_num=random.randint(0,len(config_n_blocks)-1)
        learning_num=random.randint(0,len(config_learning_rate)-1)


        parser = argparse.ArgumentParser()
        parser.add_argument("--dataset_name", type=str, default=config_dataset_name[3])
        parser.add_argument("--model_name", type=str, default="saint")
        parser.add_argument("--emb_type", type=str, default="qid")
        parser.add_argument("--save_dir", type=str, default="saved_model")
        # parser.add_argument("--learning_rate", type=float, default=1e-5)
        parser.add_argument("--seed", type=int, default=42)
        parser.add_argument("--fold", type=int, default=0)

        parser.add_argument("--dropout", type=float, default=config_dropout[drop_num])
        parser.add_argument("--emb_size", type=int, default=config_emb_size[emb_num])
        parser.add_argument("--learning_rate", type=float, default=config_learning_rate[learning_num])
        parser.add_argument("--num_attn_heads", type=int, default=config_num_attn_heads[heads_num])
        parser.add_argument("--n_blocks", type=int, default=config_n_blocks[blocks_num])
        parser.add_argument("--use_wandb", type=int, default=0)
        parser.add_argument("--add_uuid", type=int, default=1)

    
        args = parser.parse_args()
        params = vars(args)

        main(params)
    # main(params)
best_auc,best_acc, best_epoch = train_model(model, train_loader, valid_loader, num_epochs, opt, ckpt_path, None, None, save_model)
    
    
    data = [best_auc,best_acc,best_epoch, params["dataset_name"], params["dropout"],params["emb_size"],params["learning_rate"],params["num_attn_heads"],params["n_blocks"]]
    file_path = './result/'+params["dataset_name"]+'_saint_2.csv'
    with open(file_path, 'a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(data)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

铁灵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值