libcity笔记:详细流程(以DeepMove为例)

0 前置操作

这边我选择了gowalla的前1000条数据做例子:

0.1 生成样例dyna

import pandas as pd
geo=pd.read_csv('/home_nfs/liushuai/Bigscity-LibCity/raw_data/gowalla_test/gowalla.dyna')

geo_tst=geo.iloc[:1000,:]
geo_tst

 

geo_tst.to_csv('/home_nfs/liushuai/Bigscity-LibCity/raw_data/gowalla_test/gowalla.dyna', index=False)

 0.2 生成相应geo

geo=pd.read_csv('/home_nfs/liushuai/Bigscity-LibCity/raw_data/gowalla/gowalla.geo')
geo=geo[geo['geo_id'].isin(set(geo_tst.location))]
geo.to_csv('/home_nfs/liushuai/Bigscity-LibCity/raw_data/gowalla_test/gowalla.geo')

 

0.3 修改libcity/config/task_config.json

1 主调用

 python run_model.py --task traj_loc_pred --model DeepMove --dataset gowalla --batch_size=5
  • 有task、dataset、model三个必须命令行参数
  • batch_size一个可选命令行参数
  • 没有confg_file 

加载所有参数

1.1 libcity/utils/argument_list.py/str2bool

  • 将字符串表示的布尔值转换为 Python 中的布尔值。
    • 首先检查输入的参数是否已经是布尔值类型,如果是,则直接返回该值,无需转换。

    • if s.lower() in ('yes', 'true')::检查字符串是否是 'yes''true',如果是,则返回 True

    • elif s.lower() in ('no', 'false')::检查字符串是否是 'no''false',如果是,则返回 False

    • else::如果字符串既不是 'yes'/'true' 也不是 'no'/'false',则抛出 argparse.ArgumentTypeError 异常,表示期望一个布尔值。

1.2 libcity/utils/argument_list.py/add_general_args

2 libcity/pipeline/pipeline.py/run_model

2.1 libcity/config/config_parser.py/ConfigParser

2.1.1 构造函数

libcity笔记:libcity/config/config_parser.py/ConfigParser-CSDN博客

2.1.2 get

libcity笔记:libcity/config/config_parser.py/ConfigParser-CSDN博客

2.2 libcity/utils/utils.py/get_logger

libcity笔记:libcity/utils/utils.py-CSDN博客

 然后连着两行logger.info

2.3 set_random_seed

libcity笔记:libcity/utils/utils.py-CSDN博客

2.4 get_dataset

libcity 笔记:libcity/data/utils.py-CSDN博客

得到相应的TrajectoryDataset

  • 得到dataset_cache和cut_traj的json文件,缓存处理的轨迹数据

2.5 dataset.get_data()

2.5.1 cutter_filter

  • 由于之前没有dataset_cache和cut_traj的缓存json文件,先调用cuttter_filter【 cut_data = self.cutter_filter()】
    • 得到的结果,是一个字典,key是user_id,value是一系列二维数组组成的列表,每个二维数组的每一行是“dyna_id    type    time    entity_id    location”

2.5.2 encode_traj

  • 【encoded_data = self.encode_traj(cut_data)】
    • {
                  'data_feature': self.encoder.data_feature,
                  'pad_item': self.encoder.pad_item,
                  'encoded_data': encoded_data
              }的一个字典,其中:

2.5.2  divide_data

  • train_data, eval_data, test_data = self.divide_data()
    • 其中每一个元素也就是

 2.5.3 generate_dataloader_pad

  • 传入的参数
    • train_data, eval_data, test_data
    • self.encoder.feature_dict,
    • self.config['batch_size'],
    • self.config['num_workers'],
    • self.pad_item,
    • self.encoder.feature_max_len
      • 没有设置,就是默认的{}
  • libcity 笔记:libcity/data/utils.py-CSDN博客

2.6 get_data_feature

Libcity笔记:libcity/data/dataset/trajectory_encoder/standard_trajectory_encoder.py-CSDN博客

2.7 get_model

从ibcity/model/trajectory_loc_prediction/DeepMove.py 中生成 DeepMove类

2.8 get_executor

2.8.1 get_evaluator

TrajLocPredEvaluator

libcity笔记:libcity/evaluator/traj_loc_pred_evaluator.py-CSDN博客

2.8.2 self.metrics

这里是'Recall@5'

2.8.3 获取优化器和调度器

Adam+ReduceLROnPlateau

2.9 executor.train

  • 对每一个epoch:

2.9.1 run

executor.run(train_dataloader, self.model,
                                            self.config['learning_rate'], self.config['clip'])

loss_func是Deepmove的calculate_loss

然后就是逐batch训练模型

2.9.2 剩余操作

  • 计算validation loss
  • 更新学习率
  • 判断是否需要早停。。。

2.10 executor.evaluate(test_data)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值