import argparse
from argparse import Namespace
from wandb_train import main
from pprint import pprint
import wandb
import random
import os
import csv
# Random hyperparameter search for the SAINT model: each trial samples one
# configuration from the lists below, runs wandb_train.main on it, and appends
# the returned metrics to a per-dataset CSV under ./result/.

config_dataset_name = ['assist2009', 'assist2012', 'assist2017', 'nips_task34']
config_dropout = [0.1, 0.2, 0.3, 0.4]
config_emb_size = [64, 128, 256]
config_num_attn_heads = [1, 2, 4, 8]
config_n_blocks = [1, 2, 4, 8]
config_learning_rate = [0.1, 0.01, 0.001, 0.003, 0.005]

# Number of random trials; trial i samples its configuration from seed i.
NUM_TRIALS = 900

# Column order of the results CSV; must match the rows built in the main loop.
RESULT_HEADER = ['best_auc', 'best_acc', 'best_epoch', 'dataset_name', 'dropout',
                 'emb_size', 'learning_rate', 'num_attn_heads', 'n_blocks']


def sample_config(seed):
    """Return one hyperparameter combination drawn deterministically from *seed*.

    A private ``random.Random(seed)`` is used, so the global RNG is not
    reseeded. ``choice`` consumes the generator exactly like
    ``randint(0, len - 1)``, and the draw order (dropout, emb_size,
    num_attn_heads, n_blocks, learning_rate) is kept. Trial *i* therefore gets
    the same configuration as with ``random.seed(i)`` followed by indexed
    ``randint`` draws.

    Returns:
        dict mapping CLI destination names to the sampled values.
    """
    rng = random.Random(seed)
    return {
        "dropout": rng.choice(config_dropout),
        "emb_size": rng.choice(config_emb_size),
        "num_attn_heads": rng.choice(config_num_attn_heads),
        "n_blocks": rng.choice(config_n_blocks),
        "learning_rate": rng.choice(config_learning_rate),
    }


def build_parser():
    """Build the command-line parser describing one training run.

    The searched hyperparameters get no default here. The search loop installs
    each trial's sampled values with ``parser.set_defaults``, so an explicit
    command-line flag still overrides the sampled value. Arguments are added in
    a fixed order, which fixes the key order of ``vars(parser.parse_args())``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default=config_dataset_name[3])
    parser.add_argument("--model_name", type=str, default="saint")
    parser.add_argument("--emb_type", type=str, default="qid")
    parser.add_argument("--save_dir", type=str, default="saved_model")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fold", type=int, default=0)
    # Searched hyperparameters: defaults are set per trial via set_defaults().
    parser.add_argument("--dropout", type=float)
    parser.add_argument("--emb_size", type=int)
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--num_attn_heads", type=int)
    parser.add_argument("--n_blocks", type=int)
    parser.add_argument("--use_wandb", type=int, default=0)
    parser.add_argument("--add_uuid", type=int, default=1)
    return parser


def result_file_path(dataset_name, result_dir="./result"):
    """Return the results-CSV path for *dataset_name* inside *result_dir*."""
    return os.path.join(result_dir, dataset_name + '_saint_2.csv')


def ensure_result_file(file_path):
    """Create *file_path* (and its parent directory) with the CSV header row.

    If the file already exists it is left untouched and a notice is printed.

    Returns:
        True if the file was created, False if it already existed.
    """
    if os.path.exists(file_path):
        print("文件存在")  # "file exists"
        return False
    parent = os.path.dirname(file_path)
    if parent:
        # The result directory may not exist yet on a fresh checkout.
        os.makedirs(parent, exist_ok=True)
    with open(file_path, 'w', newline='', encoding='utf-8') as handle:
        csv.writer(handle).writerow(RESULT_HEADER)
    return True


def append_result_row(file_path, row):
    """Append one result *row* (ordered like RESULT_HEADER) to *file_path*."""
    with open(file_path, 'a', newline='', encoding='utf-8') as handle:
        csv.writer(handle).writerow(row)


if __name__ == "__main__":
    parser = build_parser()
    # Resolve the dataset once (the CLI may override the default) so that the
    # header is written to the same file every trial appends to.
    dataset_name = parser.parse_args().dataset_name
    file_path = result_file_path(dataset_name)
    ensure_result_file(file_path)

    for trial in range(NUM_TRIALS):
        parser.set_defaults(**sample_config(trial))
        params = vars(parser.parse_args())
        outcome = main(params)
        # NOTE(review): assumes wandb_train.main returns
        # (best_auc, best_acc, best_epoch) — confirm against wandb_train.
        if outcome is None:
            raise RuntimeError(
                "wandb_train.main() returned None; it must return "
                "(best_auc, best_acc, best_epoch) for the results CSV."
            )
        best_auc, best_acc, best_epoch = outcome
        row = [best_auc, best_acc, best_epoch, params["dataset_name"],
               params["dropout"], params["emb_size"], params["learning_rate"],
               params["num_attn_heads"], params["n_blocks"]]
        append_result_row(file_path, row)