代码详解(小白版)-STResNet for Citywide Crowd Flows Prediction

论文:Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction

代码我看不太懂……我找解析,找不到……

所以,我根据自己的理解写了一下pytorch版代码的详解:

代码完整复现请见这位大佬的,我只是自己分析了代码。

STResNet-PyTorch: Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction pytorch版本 (gitee.com)

看我解析的时候一定要对照源码思考!!!

放心我是小白,我的基础不是很好,所以我也写的详细,对大家来说肯定是比较容易理解滴。

本文先分析预处理包preprocessing中的三个部分:MaxMinNormalization.py、timestamp.py、STMatrix.py

一、MaxMinNormalization.py

1.公式

原来的公式只有x* = (x-min)/(max-min)

但最后的范围是[-1,1],加了一步X=X*2 -1

只看这个MaxMinNormalization归一化的话,只有x* = (x-min)/(max-min)这一步,但得到的范围就是[0,1]了

2.def inverse_transform(self, X):

X = (X + 1.) / 2.

    X = 1. * X * (self._max - self._min) + self._min

return X

这个函数就是把归一化的公式又倒回去了

3.MaxMinNormalization.py源代码

# Acknowledgement: This code is taken from https://github.com/TolicWang/DeepST
import numpy as np


np.random.seed(1337)  # for reproducibility


class MinMaxNormalization(object):
    """
        MinMax Normalization --> [-1, 1]
        x = (x - min) / (max - min).
        x = x * 2 - 1
    """

    def __init__(self):
        pass

    def fit(self, X):
        self._min = X.min()
        self._max = X.max()
        print("min:", self._min, "max:", self._max)

    def transform(self, X):
        X = 1. * (X - self._min) / (self._max - self._min)
        X = X * 2. - 1.
        return X

    def fit_transform(self, X):
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, X):
        X = (X + 1.) / 2.
        X = 1. * X * (self._max - self._min) + self._min
        return X

二、timestamp.py

1.函数string2timestamp

原论文在Table1中写到对于TaxiBJ数据,采样间隔是30min,采样数据大小是(32,32)

一天24小时,每隔30min采样一次就是每天采样48次,也就是string2timestamp(strings, T=48)出现的参数T=48

time_per_slot=0.5,每次采样的时间是0.5h

num_per_T=2,每个小时中有两次采样

输入的str格式是[b'2013070101', b'2013070102'],下标为0123代表年,45代表月,67代表日,89代表这天的第几次采样,下标89所组成的二位数最大是48.比如,'2013070101'代表‘2013.7.1的00:00点’,'2013070105'代表‘2013.7.1的02:00点’

slot =int(t[8:]) - 1其中减掉1的意义?当采样次数为01时,实际的时间为半夜的00:00点

hour=int(slot * time_per_slot)现在的时间,就等于采样的次数乘上每次采样的时间.如果时间是02:30,按照这个式子算出来是2.5,所以要取整

minute=(slot % num_per_T) * int(60.0 * time_per_slot),%表示采样次数整除2后取余,也就是看现在的时间是一个整的小时02:00之类的还是一个半的小时02:30.如果是整的小时,前面式子求出来是0,整个式子取0;如果是半的小时,前面的式子求出来是1,后面式子求出来是30,最后就是minute = 30

整个函数最后返回值[Timestamp('2013-07-01 00:00:00'), Timestamp('2013-07-01 00:30:00')]

2.函数timestamp2vec

补充知识:

(1)time.striptime()根据指定的格式,把一个时间字符串解析为时间元组。格式化符号有很多,代码中出现的%Y表示四位数的年份(两位数年份用y%),%m表示月份,%d表示月内的一天。举个例子,import time,time_a = time.strptime("20240924","%Y%m%d"),print(time_a)得time.struct_time(tm_year=2024, tm_mon=9, tm_mday=24, tm_hour=0, tm_min=0, tm_sec=0, tm_wday=1, tm_yday=268, tm_isdst=-1)。其中,tm_wday表示每周的第几天,取值是0-6,这里取的1表示是周二。tm_isdst表示是否是夏令时,取值是1是、0不是、-1未知,默认是-1.

(2)np.asarray()主要是要与np.array()区分,都是返回一个ndarray(n维数组对象)。前者np.asarray()不会改变地址。也就是说,假设有个数组a,b = np.asarray(a),c =np.array(a),那么更改a后,再输出bc,会发现b跟a一样变了,c还是原来的a。

这里vec就是取出了每个时间是周几的信息。

然后对vec遍历,每次遍历:先建一个数组v=[0,0,0,0,0,0,0]有七个位置,对应着每星期的七天。拿出来的这个数据是周几,第几位0就变成1。比如,遍历的i=3,数组现在就成了v=[0,0,1,0,0,0,0]。像是叫独热编码(把离散的数据数字化)还是稀疏矩阵,我猜的。然后有个if如果是周末的话,列表v后面再加个0,如果是工作日,就加个1.比如这里是周三,最后v=[0,0,1,0,0,0,0,1].最后把v这个列表作为一个元素加到列表ret中。

这样ret就是列表嵌套列表了,像是二维的。

3.if__name__ ==__main__:

对于其中的内容,如果这个.py文件作为包被导入别的文件中,这些内容不会被执行,只在本原文件中才被执行。

4.timestamp.py源代码

# Acknowledgement: This code is taken from https://github.com/TolicWang/DeepST
import time
import pandas as pd
import numpy as np
from datetime import datetime


def string2timestamp(strings, T=48):
    """
    :param strings:
    :param T:
    :return:
    example:
    str = [b'2013070101', b'2013070102']
    print(string2timestamp(str))
    [Timestamp('2013-07-01 00:00:00'), Timestamp('2013-07-01 00:30:00')]
    """
    timestamps = []

    time_per_slot = 24.0 / T
    num_per_T = T // 24
    for t in strings:
        year, month, day, slot = int(t[:4]), int(t[4:6]), int(t[6:8]), int(t[8:]) - 1
        timestamps.append(pd.Timestamp(datetime(year, month, day, hour=int(slot * time_per_slot),
                                                minute=(slot % num_per_T) * int(60.0 * time_per_slot))))

    return timestamps
def timestamp2vec(timestamps):
    """
    :param timestamps:
    :return:
    exampel:
    [b'2018120505', b'2018120106']
    #[[0 0 1 0 0 0 0 1]  
     [0 0 0 0 0 1 0 0]]  

    """
    # tm_wday range [0, 6], Monday is 0
    vec = [time.strptime(str(t[:8],encoding='utf-8'), '%Y%m%d').tm_wday for t in timestamps]  # python3
    # vec = [time.strptime(t[:8], '%Y%m%d').tm_wday for t in timestamps]  # python2
    ret = []
    for i in vec:
        v = [0 for _ in range(7)]
        v[i] = 1
        if i >= 5:
            v.append(0)  # weekend
        else:
            v.append(1)  # weekday
        ret.append(v)
    return np.asarray(ret)


if __name__ == "__main__":
    # t = ['2013-06-30 23:30:00']#
    t= [b'2018120505', b'2018120106']
    print(timestamp2vec(t))
    print([0 for _ in range(7)])

三、STMatrix.py

1.class STMatrix(object)

括号里的object是继承的父类。

__ init__ (self)函数会在创建类时自动执行,必须包含一个参数self。这里self就是指的实例对象本身。

super(STMatrix, self).__init__()就是继承了父类object的__init__()中的内容。

assert是确定这个语句是否正确,如果后面错误,会报错。

check_complete和make_index是类中定义的函数,在下文分析。

CheckComplete,好吧,其实刚开始我没看类中的函数,我不知道这个参数的作用。我问了下灵码,它说:“在给出的代码段中,self.check_complete() 是一个方法调用,它应该是在类 STMatrix 中定义的一个方法。这个方法的具体作用是检查数据集的完整性。确保 timestamps 列表中的时间戳是连续的,没有跳过任何一个预期的时间点。确认每个时间戳都有对应的数据条目,即没有缺失的数据点。如果发现数据不完整,可能会记录一条错误消息或抛出一个异常,提示用户数据存在问题。对于时间序列分析而言,数据的连续性和完整性是非常重要的。”

2.函数make_index

这里创建了一个空字典get_index。

enumerate函数,举个例子。a = [‘sunny’,’cat’,’grass’,’kite’],b = enumerate(a),那么for i in b:

print(i),得到(0, 'sunny')(1, 'cat')(2, 'grass')(3, 'kite')。注意b返回值是一个地址,看其中的元素需要遍历出来。

3.函数check_complete

pd.DateOffset()函数,见名知意,是关于时间偏移的。举个例子:

t = pd.Timestamp('2024-09-25 07:30:00')
print(t)
t2 = t + pd.DateOffset(months=1)
print(t2)
t3 = t + pd.DateOffset(hours=1)
print(t3)
t4 = t + pd.DateOffset(minutes=1)
print(t4)

t5 = t + pd.DateOffset(months=-1)
print(t5)

运行结果:

2024-09-25 07:30:00

2024-10-25 07:30:00

2024-09-25 08:30:00

2024-09-25 07:31:00

2024-08-25 07:30:00

也可以看出来这里的参数正负都可以

代码里面算出来就是每次偏移30min(根据论文中的数据,T参数的值一直是48)

pd_timestamps的数据是这样的:[Timestamp('2013-07-01 00:00:00'), Timestamp('2013-07-01 00:30:00')](根据前面__init__函数提到的,它是string2timestamp(timestamps, T=self.T)这个函数的返回)

If条件的意思是,取出一个时间戳,这个时间加上30min的那个时间不在时间戳列表里,那么表明数据不是连续的(如果连续的话,每30min就有一个时间戳)

append("(%s -- %s)" % (pd_timestamps[i - 1], pd_timestamps[i]))这里%s是一个格式化,用后面元组里的元素替换,最后返回一个字符串。举个例子:

"(%s -- %s)" % ('2024-09-25 07:30:00', '2024-09-25 09:00:00')

print("(%s -- %s)" % ('2024-09-25 07:30:00', '2024-09-25 09:00:00'))

返回(2024-09-25 07:30:00 -- 2024-09-25 09:00:00)

然后把这个加在missing_stamps中,表示在这两个时间点之间还缺了时间戳

最后assert确保时间是连续的

4.函数get_matrix

出现的get_index是一个字典,在上文的make_index函数中提到了。字典的键是时间戳,字典的值是时间戳的序号。这里self.get_index[timestamp]先取到这个时间戳的序号,再在data里根据这个序号找数据

5.函数save

pass了……pass是个占位符,什么都不做,就占个位置

6.函数check_it

depends上文没出现啊,最近的是在下文的create_dataset函数中

这里不妨碍函数的理解,就是看这个depends中的元素是不是字典self.get_index[timestamp]中的键

7.函数create_dataset

参数len_closeness=3,len_period=3都是论文中提到的长度,这里我不明白这长度是什么意思。是不是这样—拿trend来说,每7天一个间隔,这个长度等于3的意思是拿出3个间隔,也就是三周的数据吗?仅供猜测,望大佬指点。

参数TrendInterval=7表示trend块的时间间隔是7天即一周,PeriodInterval=1表示period块的时间间隔是1天。这两个变量都是论文中规定的大小,可以再回顾一下论文。

offset_frame就是时间偏移30min。可是他为什么起这么个名字?

写出来,depends = [range(1,4),[1*48*1,1*48*2,1*48*3],[7*48*1,7*48*2,7*48*3]]这里乘法应该是要算出最终得数的,但这样写比较直观。我猜测前面range(1,4)要写成[1,2,3]的形式吧,用range感觉怪怪的。对了,后面[PeriodInterval * self.T * j for j in range(1, len_period + 1)]是个列表生成式,简单的一个例子是[i*2 for i in range(3)],结果是[0,2,4]

self.T * TrendInterval * len_trend指的是trend数据时间戳的个数,以此类推另外两个period和closeness的个数

max的这个i 指的是三组时间戳个数中最大的一个

While 保证数据集pd_stamps中有所需要的数量的足够的数据

每次while循环

先看从depends中拿出的depend,分别是range(1,4)、[48, 96, 144]、 [336, 672, 1008]

最开始Flag是True的,重新赋值了Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend]),后面又是一个列表生成式。depend中存的是每块时间间隔(trend是1周,period是1天)有的时间戳的个数。也就从需要的时间的最后一个值,这里也就是下标为i的值,一段时间间隔一段时间间隔往前偏移,看看这一个个偏移后的时间是否在字典get_index中。一旦有时间点不在,Flag成了False,i要再加上1,直接跳过后面代码直接进行下一次while循环。这里出错就重新一次次循环,是为了得到一段完整的时间序列吗?可是前面assert len(missing_timestamps) == 0不是已经保证是连续的吗?这里我不是很清楚。

直到不出错,然后给x_c,x_p,x_t赋值。注意这里的i已经是整个数据集中的下标了,不是所需要数据的下标。

拿x_c来说,啊这里我看到代码中大佬的注释明白了!我第一次读这个注释的时候还云里雾里的,看来多看几遍确实有助于理解。说正事:

刚刚我不明白的len_closeness现在明白了,就是x_c的长度是3,比如说,其中一个x_c是[Timestamp‘2024-09-25 09:00:00’,Timestamp‘2024-09-25 08:30:00’,Timestamp‘2024-09-25 08:00:00’],包括三个时间戳,而且前面时间离现在时间近

x_c这个列表中有i之前的三个相邻的时间戳。同理

x_p这个列表中有三个i之前间隔一天的时间戳

x_t这个列表中有三个i之前间隔为一周的时间戳

y就是i处实际的时间戳

np.vstack就是把列表(也可以看作是一维向量)顺着行排起来。举个例子:

a = np.array([1,2,3])

b = np.array([7,8,9])

print(np.vstack(a,b))结果是:

[[1 2 3]

[7 8 9]]

(相对地,np.hstack就是水平顺着列排起来)

最后i+1是想把所有能够获得的而且符合条件的时间都放进去吗?

最后得到了XC、XP、XT,二维列表。Y、timestamps_Y是一维列表。

这里timestamps_Y与Y的区别?

timestamps_Y的元素是timestamps,是这个格式[b'2024092526', b'2024092527']

Y的元素是pd_timestamps,是这个格式[Timestamp('2024-09-25 12:30:00'), Timestamp('2024-09-25 13:00:00')]

8.if __name__ == '__main__':

作用跟上文提到的一样

9.STMatrix.py源代码

# Acknowledgement: This code is taken from https://github.com/TolicWang/DeepST
import numpy as np
import pandas as pd
from .timestamp import string2timestamp


class STMatrix(object):
    """docstring for STMatrix"""

    def __init__(self, data, timestamps, T=48, CheckComplete=True):
        super(STMatrix, self).__init__()
        assert len(data) == len(timestamps)
        self.data = data
        self.timestamps = timestamps# [b'2013070101', b'2013070102']
        self.T = T
        self.pd_timestamps = string2timestamp(timestamps, T=self.T)
        if CheckComplete:
            self.check_complete()
        # index
        self.make_index()  # 将时间戳:做成一个字典,也就是给每个时间戳一个序号

    def make_index(self):
        self.get_index = dict()
        for i, ts in enumerate(self.pd_timestamps):
            self.get_index[ts] = i

    def check_complete(self):
        missing_timestamps = []
        offset = pd.DateOffset(minutes=24 * 60 // self.T)
        pd_timestamps = self.pd_timestamps
        i = 1
        while i < len(pd_timestamps):
            if pd_timestamps[i - 1] + offset != pd_timestamps[i]:
                missing_timestamps.append("(%s -- %s)" % (pd_timestamps[i - 1], pd_timestamps[i]))
            i += 1
        for v in missing_timestamps:
            print(v)
        assert len(missing_timestamps) == 0

    def get_matrix(self, timestamp):  # 给定时间戳返回对于的数据
        return self.data[self.get_index[timestamp]]

    def save(self, fname):
        pass

    def check_it(self, depends):
        for d in depends:
            if d not in self.get_index.keys():
                return False
        return True

    def create_dataset(self, len_closeness=3, len_trend=3, TrendInterval=7, len_period=3, PeriodInterval=1):
        """current version

        """
        # offset_week = pd.DateOffset(days=7)
        offset_frame = pd.DateOffset(minutes=24 * 60 // self.T)  # 时间偏移 minutes = 30
        XC = []
        XP = []
        XT = []
        Y = []
        timestamps_Y = []
        depends = [range(1, len_closeness + 1),
                   [PeriodInterval * self.T * j for j in range(1, len_period + 1)],
                   [TrendInterval * self.T * j for j in range(1, len_trend + 1)]]
        # print depends # [range(1, 4), [48, 96, 144], [336, 672, 1008]]
        i = max(self.T * TrendInterval * len_trend, self.T * PeriodInterval * len_period, len_closeness)
        while i < len(self.pd_timestamps):
            Flag = True
            for depend in depends:
                if Flag is False:
                    break
                Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend])

            if Flag is False:
                i += 1
                continue
            x_c = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[0]]
            # 取当前时刻的前3个时间片的数据数据构成“邻近性”模块中一个输入序列
            # 例如当前时刻为[Timestamp('2013-07-01 00:00:00')]
            # 则取:
            # [Timestamp('2013-06-30 23:30:00'), Timestamp('2013-06-30 23:00:00'), Timestamp('2013-06-30 22:30:00')]
            #  三个时刻所对应的in-out flow为一个序列
            x_p = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[1]]
            # 取当前时刻 前 1*PeriodInterval,2*PeriodInterval,...,len_period*PeriodInterval
            # 天对应时刻的in-out flow 作为一个序列,例如按默认值为 取前1、2、3天同一时刻的In-out flow
            x_t = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[2]]
            # 取当前时刻 前 1*TrendInterval,2*TrendInterval,...,len_trend*TrendInterval
            # 天对应时刻的in-out flow 作为一个序列,例如按默认值为 取 前7、14、21天同一时刻的In-out flow
            y = self.get_matrix(self.pd_timestamps[i])
            if len_closeness > 0:
                XC.append(np.vstack(x_c))
                # a.shape=[2,32,32] b.shape=[2,32,32] c=np.vstack((a,b)) -->c.shape = [4,32,32]
            if len_period > 0:
                XP.append(np.vstack(x_p))
            if len_trend > 0:
                XT.append(np.vstack(x_t))
            Y.append(y)
            timestamps_Y.append(self.timestamps[i])#[]
            i += 1
        XC = np.asarray(XC)  # 模拟 邻近性的 数据 [?,6,32,32]
        XP = np.asarray(XP)  # 模拟 周期性的 数据 隔天
        XT = np.asarray(XT)  # 模拟 趋势性的 数据 隔周
        Y = np.asarray(Y)# [?,2,32,32]
        print("XC shape: ", XC.shape, "XP shape: ", XP.shape, "XT shape: ", XT.shape, "Y shape:", Y.shape)
        return XC, XP, XT, Y, timestamps_Y


if __name__ == '__main__':
    # depends = [range(1, 3 + 1),
    #            [1 * 48 * j for j in range(1, 3 + 1)],
    #            [7 * 48 * j for j in range(1, 3 + 1)]]
    # print(depends)
    # print([j for j in depends[0]])
    str = ['2013070101']
    t = string2timestamp(str)
    offset_frame = pd.DateOffset(minutes=24 * 60 // 48)  # 时间偏移 minutes = 30
    print(t)
    o = [t[0] - j * offset_frame for j in range(1, 4)]
    print(o)

四、后续

后续会继续更新TaxiBj.py和STResNet.py代码的详细解析,希望能够与大家分享讨论。

注:本人是刚接触论文的大二生,全文的代码分析都是自己的理解,难免会有错误,欢迎指出!我会努力改正的!望各位大佬指点!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值