Paper companion notes: LLM-MOB code walkthrough

Paper notes: Where Would I Go Next? Large Language Models as Human Mobility Predictors (CSDN blog)

1 Main function

1.1 Importing libraries

import os
import pickle
import time
import ast
import logging
from datetime import datetime
import pandas as pd
from openai import OpenAI

client = OpenAI(
        api_key=...
    )

1.2 Reading the parameters

dataname = "geolife"  
# 数据集名称
    
num_historical_stay = 40  
# 长期mobility 的跨度
    
num_context_stay = 5  
# 短期mobility的跨度

top_k = 10  
# 输出location的数量

with_time = False  
# 是否将目标stay的时间信息融入进来

sleep_single_query = 0.1 
''' 
the sleep time between queries 
after the recent updates, the reliability of the API is greatly improved
so we can reduce the sleep time
'''


sleep_if_crash = 1  
'''
the sleep time if the server crashes
'''


output_dir = f"output/{dataname}/top10_wot"  
'''
the output path
'''
    
log_dir = f"logs/{dataname}/top10_wot"  
'''
the log dir
'''

1.3 Loading the dataset

tv_data, test_file = get_dataset(dataname)
#Number of total test sample:  3459

'''
This count is smaller than the number of rows in f"data/{dataname}/{dataname}_test.csv".
'''

1.4 Log file generation

logger = get_logger('my_logger', log_dir=log_dir)

1.5 Extracting user ids

uid_list = get_unqueried_user(dataname, output_dir)
print(f"uid_list: {uid_list}")

'''
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
Number of the remaining id: 45
uid_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
'''

1.6 Generating the queries

query_all_user(
    client, 
    dataname, 
    uid_list, 
    logger, 
    tv_data, 
    num_historical_stay, 
    num_context_stay,
    test_file, 
    output_dir=output_dir, 
    top_k=top_k, 
    is_wt=with_time,
    sleep_query=sleep_single_query, 
    sleep_crash=sleep_if_crash)

2 get_dataset

def get_dataset(dataname):
    
    # Get training and validation set and merge them
    train_data = pd.read_csv(f"data/{dataname}/{dataname}_train.csv")
    valid_data = pd.read_csv(f"data/{dataname}/{dataname}_valid.csv")
    # read the training and validation sets

    # Get test data
    with open(f"data/{dataname}/{dataname}_testset.pk", "rb") as f:
        test_file = pickle.load(f)  # test_file is a list of dict
    # the test set

    # merge train and valid data
    tv_data = pd.concat([train_data, valid_data], ignore_index=True)
    tv_data.sort_values(['user_id', 'start_day', 'start_min'], inplace=True)
    if dataname == 'geolife':
        tv_data['duration'] = tv_data['duration'].astype(int)
    # merge the training and validation sets

    print("Number of total test sample: ", len(test_file))
    return tv_data, test_file
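Judging from the keys that organise_data reads later (user_X, start_min_X, weekday_X, dur_X, X, start_min_Y, weekday_Y, Y), each element of test_file seems to be a dict describing one test sample. A hypothetical geolife element might look like the following; the values are illustrative only, reusing the example stays shown later in these notes.

# Hypothetical structure of one test_file element (geolife); inferred from
# the keys accessed in organise_data, values for illustration only.
sample = {
    'user_X': [1, 1, 1, 1, 1],                  # user id, repeated per context stay
    'start_min_X': [1268, 298, 487, 635, 834],  # start times (minutes since midnight)
    'weekday_X': [2, 3, 3, 3, 3],               # day-of-week indices (0 = Monday)
    'dur_X': [466, 187, 146, 193, 556],         # durations of the context stays
    'X': [10, 17, 1, 17, 10],                   # location ids of the context stays
    'start_min_Y': 36,                          # start time of the target stay
    'weekday_Y': 4,                             # day-of-week of the target stay
    'Y': 10                                     # ground-truth location id
}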

3 get_logger

def get_logger(logger_name, log_dir='logs/'):
    # Create log dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Create a logger instance
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    # create a logger instance named logger_name and set its level to DEBUG

    # Create a console handler and set its log level
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    # create a console handler that prints log messages to the console

    # Create a file handler and set its log level
    current_datetime = datetime.now()
    formatted_datetime = current_datetime.strftime("%Y%m%d_%H%M%S")
    # get the current date and time formatted as "YYYYMMDD_HHMMSS"; used in the log file name
    log_file = 'log_file' + formatted_datetime + '.log'
    # append the formatted timestamp to "log_file" to build the file name, e.g. log_file20230424_153000.log
    log_file_path = os.path.join(log_dir, log_file)
    # build the full log file path with os.path.join(log_dir, log_file)

    
    file_handler = logging.FileHandler(log_file_path)
    file_handler.setLevel(logging.DEBUG)
    # create a file handler that writes log messages to the specified file

    

    # Create a formatter and add it to the handlers
    formatter = logging.Formatter('%(message)s')
    # create a formatter that outputs only the message body, i.e. '%(message)s'
    
    console_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)
    # attach the formatter to both the console handler and the file handler

    # Add the handlers to the logger
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    return logger
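A quick usage sketch: the same message goes both to the console and to the timestamped log file.

logger = get_logger('my_logger', log_dir='logs/geolife/top10_wot')
logger.info("=================Processing user 1==================")
# printed to the console and appended to logs/geolife/top10_wot/log_file<YYYYMMDD_HHMMSS>.log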

4 get_unqueried_user

Extracts the user ids of the dataset that have not been queried yet (i.e. those without a saved output CSV).

def get_unqueried_user(dataname, output_dir='output/'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)


    if dataname == "geolife":
        all_user_id = [i+1 for i in range(45)]
    elif dataname == "fsq":
        all_user_id = [i+1 for i in range(535)]


    processed_id = [int(file.split('.')[0]) for file in os.listdir(output_dir) if file.endswith('.csv')]
    remain_id = [i for i in all_user_id if i not in processed_id]
    print(remain_id)
    print(f"Number of the remaining id: {len(remain_id)}")
    return remain_id
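Since single_user_query writes each user's results to f"{uid:02d}.csv", the list comprehension recovers the already-processed user ids from the file names. A minimal illustration:

# assume the output dir already contains results for users 1 and 7
files = ["01.csv", "07.csv", "something_else.txt"]   # hypothetical os.listdir() result
processed_id = [int(f.split('.')[0]) for f in files if f.endswith('.csv')]
print(processed_id)   # [1, 7] -> these two users are skipped on the next run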

5 query_all_user

def query_all_user(client, dataname, uid_list, logger, train_data, num_historical_stay,
                   num_context_stay, test_file, top_k, is_wt, output_dir, sleep_query, sleep_crash):
    for uid in uid_list:
        logger.info(f"=================Processing user {uid}==================")
        user_train = get_user_data(train_data, uid, num_historical_stay, logger)
        # the long-term historical mobility (the most recent M stays) of the current uid




        historical_data, predict_X, predict_y = organise_data(dataname, user_train, test_file, uid, logger, num_context_stay)
        '''
        Returns, for this user id:
        - the long-term mobility (shared by all of this user's test samples)
        - the short-term mobility (the 5 most recent context stays)
        - the ground truth

        Each record has the form ('09:08 PM', 'Wednesday', 466, 10).
        '''
        
        single_user_query(client, dataname, uid, historical_data, predict_X, predict_y, logger, top_k=top_k, 
                          is_wt=is_wt, output_dir=output_dir, sleep_query=sleep_query, sleep_crash=sleep_crash)

5.1 get_user_data

Extracts the long-term historical mobility (the most recent M stays) of the current uid.

def get_user_data(train_data, uid, num_historical_stay, logger):
    user_train = train_data[train_data['user_id']==uid]
    # all records belonging to the current user id

    logger.info(f"Length of user {uid} train data: {len(user_train)}")
    # total number of records for this user id

    user_train = user_train.tail(num_historical_stay)
    # keep only the most recent num_historical_stay stays as the long-term mobility history

    logger.info(f"Number of user historical stays: {len(user_train)}")
    return user_train

5.2 organise_data

Returns, for this user id:
- the long-term mobility (shared by all of this user's test samples)
- the short-term mobility (the 5 most recent context stays)
- the ground truth

Each record has the form ('09:08 PM', 'Wednesday', 466, 10).

def organise_data(dataname, user_train, test_file, uid, logger, num_context_stay=5):
    # Use another way of organising data
    # user_train holds only the most recent M records
    historical_data = []

    if dataname == 'geolife':
        for _, row in user_train.iterrows():
            historical_data.append(
                (convert_to_12_hour_clock(int(row['start_min'])), 
                int2dow(row['weekday']),
                int(row['duration']),
                row['location_id'])
                )
    elif dataname == 'fsq':
        for _, row in user_train.iterrows():
            historical_data.append(
                (convert_to_12_hour_clock(int(row['start_min'])),
                int2dow(row['weekday']),
                row['location_id'])
                )
    '''
    Each iteration appends a tuple:

    time-of-day : start time converted to an HH:MM AM/PM string
    day-of-week : the weekday converted to its name
    duration    : the stay duration cast to int (geolife only; fsq tuples omit it)
    location id : the id of the location

    e.g.
    [('09:08 PM', 'Wednesday', 466, 10),
     ('04:58 AM', 'Thursday', 187, 17),
     ('08:07 AM', 'Thursday', 146, 1),
     ('10:35 AM', 'Thursday', 193, 17),
     ('01:54 PM', 'Thursday', 556, 10)]
    '''

    logger.info(f"historical_data: {historical_data}")
    logger.info(f"Number of historical_data: {len(historical_data)}")

    # Get user ith test data
    list_user_dict = []
    for i_dict in test_file:
        if dataname == 'geolife':
            i_uid = i_dict['user_X'][0]
        elif dataname == 'fsq':
            i_uid = i_dict['user_X']
        if i_uid == uid:
            list_user_dict.append(i_dict)
    # collect the test records whose user id equals uid into list_user_dict
    # i.e. the trajectories of this user that need to be predicted

    predict_X = []
    predict_y = []
    for i_dict in list_user_dict:
        construct_dict = {}
        if dataname == 'geolife':
            context = list(zip([convert_to_12_hour_clock(int(item)) for item in i_dict['start_min_X'][-num_context_stay:]], 
                            [int2dow(i) for i in i_dict['weekday_X'][-num_context_stay:]], 
                            [int(i) for i in i_dict['dur_X'][-num_context_stay:]], 
                            i_dict['X'][-num_context_stay:]))
        elif dataname == 'fsq':
            context = list(zip([convert_to_12_hour_clock(int(item)) for item in i_dict['start_min_X'][-num_context_stay:]], 
                            [int2dow(i) for i in i_dict['weekday_X'][-num_context_stay:]], 
                            i_dict['X'][-num_context_stay:]))
        '''
        For geolife, context is a list of five elements,
        each in the same format as the tuples appended to historical_data above.
        '''


        target = (convert_to_12_hour_clock(int(i_dict['start_min_Y'])), int2dow(i_dict['weekday_Y']), None, "<next_place_id>")
        #('12:36 AM', 'Friday', None, '<next_place_id>')

        construct_dict['context_stay'] = context
        construct_dict['target_stay'] = target
        # build the input: the N most recent context stays plus the target's time-of-day and weekday

        predict_y.append(i_dict['Y'])
        # ground-truth location id
        predict_X.append(construct_dict)
        # the constructed model input
    
    

    logger.info(f"Number of predict_data: {len(predict_X)}")
    # how many test records this user id has

    logger.info(f"predict_y: {predict_y}")
    
    logger.info(f"Number of predict_y: {len(predict_y)}")
    # this count should match the length of predict_X

    return historical_data, predict_X, predict_y
    '''
    Returns, for this user id:
    - the long-term mobility (shared by all of this user's test samples)
    - the short-term mobility (the 5 most recent context stays)
    - the ground truth
    '''
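Putting the pieces together, one element of predict_X (geolife) looks like the dict below, reusing the illustrative stays from above; the matching element of predict_y is the ground-truth location id.

# one element of predict_X (values reuse the illustrative stays above)
example_predict_x = {
    'context_stay': [('09:08 PM', 'Wednesday', 466, 10),
                     ('04:58 AM', 'Thursday', 187, 17),
                     ('08:07 AM', 'Thursday', 146, 1),
                     ('10:35 AM', 'Thursday', 193, 17),
                     ('01:54 PM', 'Thursday', 556, 10)],
    'target_stay': ('12:36 AM', 'Friday', None, '<next_place_id>')
}
# the corresponding predict_y element would be, e.g., 10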

5.2.1 convert_to_12_hour_clock

# convert minutes-of-day into an HH:MM AM/PM string

def convert_to_12_hour_clock(minutes):
    # in the raw data, minutes counts the minutes elapsed since midnight of that day
    if minutes < 0 or minutes >= 1440:
        return "Invalid input. Minutes should be between 0 and 1439."

    hours = minutes // 60
    minutes %= 60

    period = "AM"
    if hours >= 12:
        period = "PM"

    if hours == 0:
        hours = 12
    elif hours > 12:
        hours -= 12

    return f"{hours:02d}:{minutes:02d} {period}"
    # return an HH:MM AM/PM string
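A few example calls (the first two reproduce the times that appear in the sample records above):

convert_to_12_hour_clock(1268)   # '09:08 PM'  (21:08)
convert_to_12_hour_clock(36)     # '12:36 AM'
convert_to_12_hour_clock(0)      # '12:00 AM'
convert_to_12_hour_clock(1500)   # 'Invalid input. Minutes should be between 0 and 1439.'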

5.2.2 int2dow

# convert an integer day index into the weekday name

def int2dow(int_day):
    tmp = {0: 'Monday', 1: 'Tuesday', 2: 'Wednesday',
           3: 'Thursday', 4: 'Friday', 5: 'Saturday', 6: 'Sunday'}
    return tmp[int_day]
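For example:

int2dow(2)   # 'Wednesday'
int2dow(4)   # 'Friday'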

5.3 single_user_query

Runs the queries for one user and saves the location prediction results.

def single_user_query(client, dataname, uid, historical_data, predict_X, predict_y,logger, top_k, is_wt, output_dir, sleep_query, sleep_crash):
    # Initialize variables
    total_queries = len(predict_X)
    logger.info(f"Total_queries: {total_queries}")
    # total number of queries for this user id

    processed_queries = 0
    current_results = pd.DataFrame({
        'user_id': None,
        'ground_truth': None,
        'prediction': None,
        'reason': None
    }, index=[])

    out_filename = f"{uid:02d}" + ".csv"
    out_filepath = os.path.join(output_dir, out_filename)



    try:
        # Attempt to load previous results if available
        current_results = load_results(out_filepath)
        processed_queries = len(current_results)
        logger.info(f"Loaded {processed_queries} previous results.")
    except FileNotFoundError:
        logger.info("No previous results found. Starting from scratch.")
    '''
    Load the prediction results that have already been saved for this user.
    '''



    # Process remaining queries
    for i in range(processed_queries, total_queries):
        # run the remaining queries for this user

        logger.info(f'The {i+1}th sample: ')
        if dataname == 'geolife':
            if is_wt is True:
                if top_k == 1:
                    completions = single_query_top1(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
            else:
                if top_k == 1:
                    completions = single_query_top1_wot(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_wot(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
        elif dataname == 'fsq':
            if is_wt is True:
                if top_k == 1:
                    completions = single_query_top1_fsq(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_fsq(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
            else:
                if top_k == 1:
                    completions = single_query_top1_wot_fsq(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_wot_fsq(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
        '''
        The complete API response object for the chosen dataset / time / top-k setting.
        '''




        response = completions.choices[0].message.content
        # the text content of the GPT response

        # Log the prediction results and usage.
        logger.info(f"Pred results: {response}")
        logger.info(f"Ground truth: {predict_y[i]}")
        logger.info(dict(completions).get('usage'))
        # token usage



        try:
            res_dict = ast.literal_eval(response)  
            # parse the GPT output string into a dict

            if top_k != 1:
                res_dict['prediction'] = str(res_dict['prediction'])

            res_dict['user_id'] = uid
            res_dict['ground_truth'] = predict_y[i]
        except Exception as e:
            res_dict = {'user_id': uid, 'ground_truth': predict_y[i], 'prediction': -100, 'reason': None}
            logger.info(e)
            logger.info(f"API request failed for the {i+1}th query")
            # time.sleep(sleep_crash)
            # if any of the steps above fails, the prediction is marked as failed
        finally:
            new_row = pd.DataFrame(res_dict, index=[0])  
            # a one-row dataframe holding the current prediction

            current_results = pd.concat([current_results, new_row], ignore_index=True)  
            # Add new row to the current df
            # accumulated prediction results for this user id

    # Save the current results
    current_results.to_csv(out_filepath, index=False)
    #save_results(current_results, out_filename)
    logger.info(f"Saved {len(current_results)} results to {out_filepath}")
    # save this user's location prediction results

    # Continue processing remaining queries
    if len(current_results) < total_queries:
        #remaining_predict_X = predict_X[len(current_results):]
        #remaining_predict_y = predict_y[len(current_results):]
        #remaining_queries = queries[len(current_results):]
        logger.info("Restarting queries from the last successful point.")
        single_user_query(client, dataname, uid, historical_data, predict_X, predict_y,
                          logger, top_k, is_wt, output_dir, sleep_query, sleep_crash)
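The parsing step inside this function assumes the model returns a Python-literal dict with at least the keys 'prediction' and 'reason'. A hypothetical response string (not an actual model output) would be handled roughly like this:

import ast

# hypothetical top-10 response string; the real output format depends on the prompt,
# but the parsing code only relies on the 'prediction' and 'reason' keys
response = "{'prediction': [10, 17, 1, 5, 27, 3, 12, 8, 20, 2], 'reason': 'The user usually stays at place 10 at this time of day.'}"
res_dict = ast.literal_eval(response)
res_dict['prediction'] = str(res_dict['prediction'])   # stringified when top_k != 1
res_dict['user_id'] = 1                                # filled in afterwards
res_dict['ground_truth'] = 10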

5.3.1 load_results

Simply loads the prediction records previously saved for this user.

def load_results(filename):
    # Load previously saved results from a CSV file    
    results = pd.read_csv(filename)
    return results

5.3.2 single_query_top1_wot

5.3.3 single_query_top1

5.3.4 single_query_top10

The top-10 query functions differ very little from the top-1 ones (the prompts are almost identical); the only difference is one extra sentence in the prompt.
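As a rough idea of their shape, here is a minimal sketch of such a query function, assuming a hypothetical prompt template (this is not the paper's actual prompt wording; only the call to get_chat_completion and the argument layout are taken from the real code):

def single_query_top10_wot_sketch(client, historical_data, X):
    # hypothetical prompt template, for illustration only
    prompt = f"""
    <historical_stays>: {historical_data}
    <context_stays>: {X['context_stay']}
    <target_stay>: {X['target_stay']}
    Predict the 10 most likely candidates for <next_place_id> and explain your reason.
    """
    return get_chat_completion(client, prompt)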

5.3.5 get_chat_completion

Sends a prompt to GPT and returns the completion object.

def get_chat_completion(client, prompt, model="gpt-3.5-turbo-0613", json_mode=False, max_tokens=1200):
    """
    args:
        client: the openai client object (new in 1.x version)
        prompt: the prompt to be completed
        model: specify the model to use
        json_mode: whether return the response in json format (new in 1.x version)
    """
    messages = [{"role": "user", "content": prompt}]
    if json_mode:
        completion = client.chat.completions.create(
            model=model,
            response_format={"type": "json_object"},
            messages=messages,
            temperature=0,  # the degree of randomness of the model's output
            max_tokens=max_tokens  # the maximum number of tokens to generate
        )
    else:
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            max_tokens=max_tokens
        )
    # res_content = response.choices[0].message["content"]
    # token_usage = response.usage
    return completion
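Usage mirrors what single_user_query does with the returned object:

completions = get_chat_completion(client, prompt)      # prompt built by a single_query_* function
response = completions.choices[0].message.content      # the generated text
token_usage = dict(completions).get('usage')           # token usage, as logged above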
