TrustGeo代码理解(六)main.py(运行模型进行训练和测试)

该博客主要解析TrustGeo的main.py文件,包括导入模块、参数初始化、训练与测试参数设置、模型初始化等环节。文章详细介绍了如何通过命令行参数设定随机种子、模型名、数据集等,以及训练过程中的超参数,如学习率、权重参数等。此外,还涵盖了模型的加载、训练和测试流程,特别强调了模型的前向传播、损失计算和检查点保存策略。
摘要由CSDN通过智能技术生成

代码链接:https://github.com/ICDM-UESTC/TrustGeo

└── TrustGeo
    ├── datasets # 包含3个大规模的真实街道IP地理位置数据集。
    │        |── New_York # 从纽约市收集的街道级IP地理定位数据集,包括91,808个IP地址。
    │        |── Los_Angeles # 从洛杉矶收集的街道级IP地理定位数据集,包括92,804个IP地址。
    │        |── Shanghai # 收集自上海的街道级IP地理定位数据集,包括126,258个IP地址。
    ├── lib # 包含模型(model)实现文件
    │        |── layers.py # 注意力机制的代码。
    │        |── model.py # TrustGeo的核心源代码。
    │        |── sublayers.py # layer.py的支持文件。
    │        |── utils.py # 辅助函数,包括视图融合的代码
    ├── asset # 包含运行模型时保存的检查点和日志
    │        |── log # 包含运行模型时保存的日志
    │        |── model # 包含运行模型时保存的检查点
    ├── preprocess.py # 预处理数据集并为模型运行执行IP聚类
    ├── main.py # 运行模型进行训练和测试
    ├── test.py #加载检查点,然后测试
    └── README.md

一、导入各种模块和数据库

# -*- coding: utf-8 -*-
import torch.nn

from lib.utils import *
import argparse, os
import numpy as np
import random
from lib.model import *
import copy
from thop import profile
import pandas as pd

整体功能是准备运行一个 PyTorch 深度学习模型的环境,具体的功能实现需要查看 lib.utils、lib.model 中的代码,以及整个文件的后续部分。

该块代码实现部分一致

RIPGeo代码理解(六)main.py(运行模型进行训练和测试)-CSDN博客

不同之处在于:

1、from thop import profile从 thop 模块中导入 profile 函数,该函数用于计算 PyTorch 模型的 FLOPs(浮点运算数)和参数数量。(在代码链接中没有找到)

2、import pandas as pd:导入 Pandas 库,用于数据处理和分析,通常用于处理表格型数据。

二、参数初始化(通过命令行参数)

parser = argparse.ArgumentParser()
# parameters of initializing
parser.add_argument('--seed', type=int, default=2022, help='manual seed')
parser.add_argument('--model_name', type=str, default='TrustGeo')
parser.add_argument('--dataset', type=str, default='New_York', choices=["Shanghai", "New_York", "Los_Angeles"],
                    help='which dataset to use')

这部分代码的目的是通过命令行参数设置一些初始化的参数,例如随机数种子、模型名称和数据集名称。这使得在运行脚本时可以通过命令行参数来指定这些参数的值。

该块代码实现部分一致

RIPGeo代码理解(六)main.py(运行模型进行训练和测试)-CSDN博客

不同之处在于:

1、parser.add_argument('--seed', type=int, default=2022, help='manual seed'):添加一个命令行参数,名称为 '--seed',表示随机数种子,类型为整数,默认值为 2022,help 参数是在命令行中输入 --help 时显示的帮助信息。(RIPGeo中默认值为 1024

2、parser.add_argument('--model_name', type=str, default='TrustGeo'):添加一个命令行参数,名称为 '--model_name',表示模型的名称,类型为字符串,默认值为 'TrustGeo'。(RIPGeo中默认值为RIPGeo

三、训练过程参数设置

# parameters of training
parser.add_argument('--beta1', type=float, default=0.9)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--lambda1', type=float, default=7e-3)
parser.add_argument('--lr', type=float, default=5e-3)
parser.add_argument('--harved_epoch', type=int, default=5) 
parser.add_argument('--early_stop_epoch', type=int, default=50)
parser.add_argument('--saved_epoch', type=int, default=200)  

这部分代码的目的是设置一些训练过程中的超参数,例如优化器的动量参数、学习率、权重参数等。这些参数在训练过程中会影响模型的更新和收敛速度。

该块代码实现部分一致

RIPGeo代码理解(六)main.py(运行模型进行训练和测试)-CSDN博客

不同之处在于:

1、parser.add_argument('--lambda1', type=float, default=7e-3):添加一个命令行参数,名称为 '--lambda1',表示某个权重参数,类型为浮点数,默认值为 7e-3。

5、parser.add_argument('--lr', type=float, default=5e-3):添加一个命令行参数,名称为 '--lr',表示学习率,类型为浮点数,默认值为 5e-3。 (RIPGeo中默认值为2e-3

四、模型参数设置

# parameters of model
parser.add_argument('--dim_in', type=int, default=30, choices=[51, 30], help="51 if Shanghai / 30 else")

opt = parser.parse_args()
print("Learning rate: ", opt.lr)
print("Dataset: ", opt.dataset)

这部分代码的目的是解析命令行参数,并打印出学习率和数据集名称。--dim_in 参数用于指定输入维度,可以选择是 51 或者 30。

该块代码实现完全一致

RIPGeo代码理解(六)main.py(运行模型进行训练和

  • 9
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值