TrustGeo代码理解(六)main.py中的__main__()中的一些理解

本文深入解析TrustGeo代码中的main.py,重点在于训练和测试时的损失计算。训练时采用NIG_loss结合三种视角进行优化,旨在模型多视角学习。测试时则计算预测与真实位置的欧氏距离,直接反映模型预测性能。设计如此,旨在确保模型在训练过程中的有效优化及测试时的直观评估。
摘要由CSDN通过智能技术生成

TrustGeo代码理解(六)main.py(运行模型进行训练和测试)-CSDN博客中的十一部分中有以下代码:

if __name__ == '__main__':
    train_data, test_data = get_data_generator(opt, train_data, test_data, normal=2)
 
    log_path = f"asset/log"
    if not os.path.exists(log_path):
        os.mkdir(log_path)
        
    f = open(f"asset/log/{opt.dataset}.txt", 'a')
    f.write(f"*********{opt.dataset}*********\n")
    f.write("dim_in="+str(opt.dim_in)+", ")
    f.write("early_stop_epoch="+str(opt.early_stop_epoch)+", ")
    f.write("harved_epoch="+str(opt.harved_epoch)+", ")
    f.write("saved_epoch="+str(opt.saved_epoch)+", ")
    f.write("lambda="+str(opt.lambda1)+", ")
    f.write("lr="+str(opt.lr)+", ")
    f.write("model_name="+opt.model_name+", ")
    f.write("seed="+str(opt.seed)+",")
    f.write("\n")
    f.close()
 
    # train
    losses = [np.inf]
    no_better_epoch = 0
    early_stop_epoch = 0
 
    for epoch in range(2000):
        print("epoch {}.    ".format(epoch))
        beta = min([(epoch * 1.) / max([100, 1.]), 1.])
        total_loss, total_mae, train_num, total_data_perturb_loss = 0, 0, 0, 0
        model.train()
        for i in range(len(train_data)):
            lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay, y_max, y_min = train_data[i]["lm_X"], \
                                                                       train_data[i]["lm_Y"], \
                                                                       train_data[i]["tg_X"], \
                                                                       train_data[i]["tg_Y"], \
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值