0 前置操作
这边我选择了gowalla的前1000条数据做例子:
0.1 生成样例dyna
import pandas as pd
geo=pd.read_csv('/home_nfs/liushuai/Bigscity-LibCity/raw_data/gowalla_test/gowalla.dyna')
geo_tst=geo.iloc[:1000,:]
geo_tst
geo_tst.to_csv('/home_nfs/liushuai/Bigscity-LibCity/raw_data/gowalla_test/gowalla.dyna', index=False)
0.2 生成相应geo
geo=pd.read_csv('/home_nfs/liushuai/Bigscity-LibCity/raw_data/gowalla/gowalla.geo')
geo=geo[geo['geo_id'].isin(set(geo_tst.location))]
geo.to_csv('/home_nfs/liushuai/Bigscity-LibCity/raw_data/gowalla_test/gowalla.geo')
0.3 修改libcity/config/task_config.json
1 主调用
python run_model.py --task traj_loc_pred --model DeepMove --dataset gowalla --batch_size=5
- 有task、dataset、model三个必须命令行参数
- batch_size一个可选命令行参数
- 没有confg_file
加载所有参数
1.1 libcity/utils/argument_list.py/str2bool
- 将字符串表示的布尔值转换为 Python 中的布尔值。
-
首先检查输入的参数是否已经是布尔值类型,如果是,则直接返回该值,无需转换。
-
if s.lower() in ('yes', 'true'):
:检查字符串是否是'yes'
或'true'
,如果是,则返回True
。 -
elif s.lower() in ('no', 'false'):
:检查字符串是否是'no'
或'false'
,如果是,则返回False
。 -
else:
:如果字符串既不是'yes'/'true'
也不是'no'/'false'
,则抛出argparse.ArgumentTypeError
异常,表示期望一个布尔值。
-
1.2 libcity/utils/argument_list.py/add_general_args
2 libcity/pipeline/pipeline.py/run_model
2.1 libcity/config/config_parser.py/ConfigParser
2.1.1 构造函数
libcity笔记:libcity/config/config_parser.py/ConfigParser-CSDN博客
2.1.2 get
libcity笔记:libcity/config/config_parser.py/ConfigParser-CSDN博客
2.2 libcity/utils/utils.py/get_logger
libcity笔记:libcity/utils/utils.py-CSDN博客
然后连着两行logger.info
2.3 set_random_seed
libcity笔记:libcity/utils/utils.py-CSDN博客
2.4 get_dataset
libcity 笔记:libcity/data/utils.py-CSDN博客
得到相应的TrajectoryDataset
- 得到dataset_cache和cut_traj的json文件,缓存处理的轨迹数据
2.5 dataset.get_data()
2.5.1 cutter_filter
- 由于之前没有dataset_cache和cut_traj的缓存json文件,先调用cuttter_filter【 cut_data = self.cutter_filter()】
- 得到的结果,是一个字典,key是user_id,value是一系列二维数组组成的列表,每个二维数组的每一行是“dyna_id type time entity_id location”
2.5.2 encode_traj
- 【encoded_data = self.encode_traj(cut_data)】
- {
'data_feature': self.encoder.data_feature,
'pad_item': self.encoder.pad_item,
'encoded_data': encoded_data
}的一个字典,其中: -
- {
2.5.2 divide_data
- train_data, eval_data, test_data = self.divide_data()
- 其中每一个元素也就是
- 其中每一个元素也就是
2.5.3 generate_dataloader_pad
- 传入的参数
- train_data, eval_data, test_data
- self.encoder.feature_dict,
- self.config['batch_size'],
- self.config['num_workers'],
- self.pad_item,
- self.encoder.feature_max_len
- 没有设置,就是默认的{}
- libcity 笔记:libcity/data/utils.py-CSDN博客
2.6 get_data_feature
Libcity笔记:libcity/data/dataset/trajectory_encoder/standard_trajectory_encoder.py-CSDN博客
2.7 get_model
从ibcity/model/trajectory_loc_prediction/DeepMove.py 中生成 DeepMove类
2.8 get_executor
2.8.1 get_evaluator
TrajLocPredEvaluator
libcity笔记:libcity/evaluator/traj_loc_pred_evaluator.py-CSDN博客
2.8.2 self.metrics
这里是'Recall@5'
2.8.3 获取优化器和调度器
Adam+ReduceLROnPlateau
2.9 executor.train
- 对每一个epoch:
2.9.1 run
executor.run(train_dataloader, self.model,
self.config['learning_rate'], self.config['clip'])
loss_func是Deepmove的calculate_loss
然后就是逐batch训练模型
2.9.2 剩余操作
- 计算validation loss
- 更新学习率
- 判断是否需要早停。。。