在 CSDN 博客《TrustGeo代码理解(六)main.py(运行模型进行训练和测试)》的第十一部分中，有以下代码：
if __name__ == '__main__':
# Wrap the raw train/test splits with the project's data generator.
# NOTE(review): the meaning of normal=2 is defined by get_data_generator,
# which is not visible here — confirm against that function.
train_data, test_data = get_data_generator(opt, train_data, test_data, normal=2)

# Append a header with this run's hyperparameters to the per-dataset log
# file (asset/log/<dataset>.txt) so entries from separate runs can be told apart.
log_path = "asset/log"
# makedirs(exist_ok=True) also creates a missing "asset" parent directory and
# avoids the race between an exists() check and a subsequent mkdir().
os.makedirs(log_path, exist_ok=True)
# The context manager guarantees the file is closed even if a write raises.
with open(os.path.join(log_path, f"{opt.dataset}.txt"), "a", encoding="utf-8") as f:
    f.write(f"*********{opt.dataset}*********\n")
    # One line of comma-separated key=value pairs. The text is byte-identical
    # to the previous per-field writes, including the final "," with no space.
    f.write(
        f"dim_in={opt.dim_in}, "
        f"early_stop_epoch={opt.early_stop_epoch}, "
        f"harved_epoch={opt.harved_epoch}, "
        f"saved_epoch={opt.saved_epoch}, "
        f"lambda={opt.lambda1}, "
        f"lr={opt.lr}, "
        f"model_name={opt.model_name}, "
        f"seed={opt.seed},"
        "\n"
    )
# train
losses = [np.inf]
no_better_epoch = 0
early_stop_epoch = 0
for epoch in range(2000):
print("epoch {}. ".format(epoch))
beta = min([(epoch * 1.) / max([100, 1.]), 1.])
total_loss, total_mae, train_num, total_data_perturb_loss = 0, 0, 0, 0
model.train()
for i in range(len(train_data)):
lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay, y_max, y_min = train_data[i]["lm_X"], \
train_data[i]["lm_Y"], \
train_data[i]["tg_X"], \
train_data[i]["tg_Y"], \