论文:Deep Spatio-Temporal Residual Networks for Citywide Crowd Flows Prediction
代码我看不太懂……我找解析,找不到……
所以,我根据自己的理解写了一下pytorch版代码的详解:
代码完整复现请见这位大佬的,我只是自己分析了代码。
看我解析的时候一定要对照源码思考!!!
放心我是小白,我的基础不是很好,所以我也写的详细,对大家来说肯定是比较容易理解滴。
本文先分析预处理包preprocessing中的三个部分:MaxMinNormalization.py、timestamp.py、STMatrix.py
一、MaxMinNormalization.py
1.公式
原来的公式只有x* = (x-min)/(max-min)
但最后的范围是[-1,1],加了一步X=X*2 -1
只看这个MaxMinNormalization归一化的话,只有x* = (x-min)/(max-min)这一步,但得到的范围就是[0,1]了
2.def inverse_transform(self, X):
X = (X + 1.) / 2.
X = 1. * X * (self._max - self._min) + self._min
return X
这个函数就是把归一化的公式又倒回去了
3.MaxMinNormalization.py源代码
# Acknowledgement: This code is taken from https://github.com/TolicWang/DeepST
import numpy as np
np.random.seed(1337) # for reproducibility
class MinMaxNormalization(object):
"""
MinMax Normalization --> [-1, 1]
x = (x - min) / (max - min).
x = x * 2 - 1
"""
def __init__(self):
pass
def fit(self, X):
self._min = X.min()
self._max = X.max()
print("min:", self._min, "max:", self._max)
def transform(self, X):
X = 1. * (X - self._min) / (self._max - self._min)
X = X * 2. - 1.
return X
def fit_transform(self, X):
self.fit(X)
return self.transform(X)
def inverse_transform(self, X):
X = (X + 1.) / 2.
X = 1. * X * (self._max - self._min) + self._min
return X
二、timestamp.py
1.函数string2timestamp
原论文在Table1中写到对于TaxiBJ数据,采样间隔是30min,采样数据大小是(32,32)
一天24小时,每隔30min采样一次就是每天采样48次,也就是string2timestamp(strings, T=48)出现的参数T=48
time_per_slot=0.5,每次采样的时间是0.5h
num_per_T=2,每个小时中有两次采样
输入的str格式是[b'2013070101', b'2013070102'],下标为0123代表年,45代表月,67代表日,89代表这天的第几次采样,下标89所组成的二位数最大是48.比如,'2013070101'代表‘2013.7.1的00:00点’,'2013070105'代表‘2013.7.1的02:00点’
slot =int(t[8:]) - 1其中减掉1的意义?当采样次数为01时,实际的时间为半夜的00:00点
hour=int(slot * time_per_slot)现在的时间,就等于采样的次数乘上每次采样的时间.如果时间是02:30,按照这个式子算出来是2.5,所以要取整
minute=(slot % num_per_T) * int(60.0 * time_per_slot),%表示采样次数整除2后取余,也就是看现在的时间是一个整的小时02:00之类的还是一个半的小时02:30.如果是整的小时,前面式子求出来是0,整个式子取0;如果是半的小时,前面的式子求出来是1,后面式子求出来是30,最后就是minute = 30
整个函数最后返回值[Timestamp('2013-07-01 00:00:00'), Timestamp('2013-07-01 00:30:00')]
2.函数timestamp2vec
补充知识:
(1)time.striptime()根据指定的格式,把一个时间字符串解析为时间元组。格式化符号有很多,代码中出现的%Y表示四位数的年份(两位数年份用y%),%m表示月份,%d表示月内的一天。举个例子,import time,time_a = time.strptime("20240924","%Y%m%d"),print(time_a)得time.struct_time(tm_year=2024, tm_mon=9, tm_mday=24, tm_hour=0, tm_min=0, tm_sec=0, tm_wday=1, tm_yday=268, tm_isdst=-1)。其中,tm_wday表示每周的第几天,取值是0-6,这里取的1表示是周二。tm_isdst表示是否是夏令时,取值是1是、0不是、-1未知,默认是-1.
(2)np.asarray()主要是要与np.array()区分,都是返回一个ndarray(n维数组对象)。前者np.asarray()不会改变地址。也就是说,假设有个数组a,b = np.asarray(a),c =np.array(a),那么更改a后,再输出bc,会发现b跟a一样变了,c还是原来的a。
这里vec就是取出了每个时间是周几的信息。
然后对vec遍历,每次遍历:先建一个数组v=[0,0,0,0,0,0,0]有七个位置,对应着每星期的七天。拿出来的这个数据是周几,第几位0就变成1。比如,遍历的i=3,数组现在就成了v=[0,0,1,0,0,0,0]。像是叫独热编码(把离散的数据数字化)还是稀疏矩阵,我猜的。然后有个if如果是周末的话,列表v后面再加个0,如果是工作日,就加个1.比如这里是周三,最后v=[0,0,1,0,0,0,0,1].最后把v这个列表作为一个元素加到列表ret中。
这样ret就是列表嵌套列表了,像是二维的。
3.if__name__ ==__main__:
对于其中的内容,如果这个.py文件作为包被导入别的文件中,这些内容不会被执行,只在本原文件中才被执行。
4.timestamp.py源代码
# Acknowledgement: This code is taken from https://github.com/TolicWang/DeepST
import time
import pandas as pd
import numpy as np
from datetime import datetime
def string2timestamp(strings, T=48):
"""
:param strings:
:param T:
:return:
example:
str = [b'2013070101', b'2013070102']
print(string2timestamp(str))
[Timestamp('2013-07-01 00:00:00'), Timestamp('2013-07-01 00:30:00')]
"""
timestamps = []
time_per_slot = 24.0 / T
num_per_T = T // 24
for t in strings:
year, month, day, slot = int(t[:4]), int(t[4:6]), int(t[6:8]), int(t[8:]) - 1
timestamps.append(pd.Timestamp(datetime(year, month, day, hour=int(slot * time_per_slot),
minute=(slot % num_per_T) * int(60.0 * time_per_slot))))
return timestamps
def timestamp2vec(timestamps):
"""
:param timestamps:
:return:
exampel:
[b'2018120505', b'2018120106']
#[[0 0 1 0 0 0 0 1]
[0 0 0 0 0 1 0 0]]
"""
# tm_wday range [0, 6], Monday is 0
vec = [time.strptime(str(t[:8],encoding='utf-8'), '%Y%m%d').tm_wday for t in timestamps] # python3
# vec = [time.strptime(t[:8], '%Y%m%d').tm_wday for t in timestamps] # python2
ret = []
for i in vec:
v = [0 for _ in range(7)]
v[i] = 1
if i >= 5:
v.append(0) # weekend
else:
v.append(1) # weekday
ret.append(v)
return np.asarray(ret)
if __name__ == "__main__":
# t = ['2013-06-30 23:30:00']#
t= [b'2018120505', b'2018120106']
print(timestamp2vec(t))
print([0 for _ in range(7)])
三、STMatrix.py
1.class STMatrix(object)
括号里的object是继承的父类。
__ init__ (self)函数会在创建类时自动执行,必须包含一个参数self。这里self就是指的实例对象本身。
super(STMatrix, self).__init__()就是继承了父类object的__init__()中的内容。
assert是确定这个语句是否正确,如果后面错误,会报错。
check_complete和make_index是类中定义的函数,在下文分析。
CheckComplete,好吧,其实刚开始我没看类中的函数,我不知道这个参数的作用。我问了下灵码,它说:“在给出的代码段中,self.check_complete() 是一个方法调用,它应该是在类 STMatrix 中定义的一个方法。这个方法的具体作用是检查数据集的完整性。确保 timestamps 列表中的时间戳是连续的,没有跳过任何一个预期的时间点。确认每个时间戳都有对应的数据条目,即没有缺失的数据点。如果发现数据不完整,可能会记录一条错误消息或抛出一个异常,提示用户数据存在问题。对于时间序列分析而言,数据的连续性和完整性是非常重要的。”
2.函数make_index
这里创建了一个空字典get_index。
enumerate函数,举个例子。a = [‘sunny’,’cat’,’grass’,’kite’],b = enumerate(a),那么for i in b:
print(i),得到(0, 'sunny')(1, 'cat')(2, 'grass')(3, 'kite')。注意b返回值是一个地址,看其中的元素需要遍历出来。
3.函数check_complete
pd.DateOffset()函数,见名知意,是关于时间偏移的。举个例子:
t = pd.Timestamp('2024-09-25 07:30:00')
print(t)
t2 = t + pd.DateOffset(months=1)
print(t2)
t3 = t + pd.DateOffset(hours=1)
print(t3)
t4 = t + pd.DateOffset(minutes=1)
print(t4)
t5 = t + pd.DateOffset(months=-1)
print(t5)
运行结果:
2024-09-25 07:30:00
2024-10-25 07:30:00
2024-09-25 08:30:00
2024-09-25 07:31:00
2024-08-25 07:30:00
也可以看出来这里的参数正负都可以
代码里面算出来就是每次偏移30min(根据论文中的数据,T参数的值一直是48)
pd_timestamps的数据是这样的:[Timestamp('2013-07-01 00:00:00'), Timestamp('2013-07-01 00:30:00')](根据前面__init__函数提到的,它是string2timestamp(timestamps, T=self.T)这个函数的返回)
If条件的意思是,取出一个时间戳,这个时间加上30min的那个时间不在时间戳列表里,那么表明数据不是连续的(如果连续的话,每30min就有一个时间戳)
append("(%s -- %s)" % (pd_timestamps[i - 1], pd_timestamps[i]))这里%s是一个格式化,用后面元组里的元素替换,最后返回一个字符串。举个例子:
"(%s -- %s)" % ('2024-09-25 07:30:00', '2024-09-25 09:00:00')
print("(%s -- %s)" % ('2024-09-25 07:30:00', '2024-09-25 09:00:00'))
返回(2024-09-25 07:30:00 -- 2024-09-25 09:00:00)
然后把这个加在missing_stamps中,表示在这两个时间点之间还缺了时间戳
最后assert确保时间是连续的
4.函数get_matrix
出现的get_index是一个字典,在上文的make_index函数中提到了。字典的键是时间戳,字典的值是时间戳的序号。这里self.get_index[timestamp]先取到这个时间戳的序号,再在data里根据这个序号找数据
5.函数save
pass了……pass是个占位符,什么都不做,就占个位置
6.函数check_it
depends上文没出现啊,最近的是在下文的create_dataset函数中
这里不妨碍函数的理解,就是看这个depends中的元素是不是字典self.get_index[timestamp]中的键
7.函数create_dataset
参数len_closeness=3,len_period=3都是论文中提到的长度,这里我不明白这长度是什么意思。是不是这样—拿trend来说,每7天一个间隔,这个长度等于3的意思是拿出3个间隔,也就是三周的数据吗?仅供猜测,望大佬指点。
参数TrendInterval=7表示trend块的时间间隔是7天即一周,PeriodInterval=1表示period块的时间间隔是1天。这两个变量都是论文中规定的大小,可以再回顾一下论文。
offset_frame就是时间偏移30min。可是他为什么起这么个名字?
写出来,depends = [range(1,4),[1*48*1,1*48*2,1*48*3],[7*48*1,7*48*2,7*48*3]]这里乘法应该是要算出最终得数的,但这样写比较直观。我猜测前面range(1,4)要写成[1,2,3]的形式吧,用range感觉怪怪的。对了,后面[PeriodInterval * self.T * j for j in range(1, len_period + 1)]是个列表生成式,简单的一个例子是[i*2 for i in range(3)],结果是[0,2,4]
self.T * TrendInterval * len_trend指的是trend数据时间戳的个数,以此类推另外两个period和closeness的个数
max的这个i 指的是三组时间戳个数中最大的一个
While 保证数据集pd_stamps中有所需要的数量的足够的数据
每次while循环:
先看从depends中拿出的depend,分别是range(1,4)、[48, 96, 144]、 [336, 672, 1008]
最开始Flag是True的,重新赋值了Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend]),后面又是一个列表生成式。depend中存的是每块时间间隔(trend是1周,period是1天)有的时间戳的个数。也就从需要的时间的最后一个值,这里也就是下标为i的值,一段时间间隔一段时间间隔往前偏移,看看这一个个偏移后的时间是否在字典get_index中。一旦有时间点不在,Flag成了False,i要再加上1,直接跳过后面代码直接进行下一次while循环。这里出错就重新一次次循环,是为了得到一段完整的时间序列吗?可是前面assert len(missing_timestamps) == 0不是已经保证是连续的吗?这里我不是很清楚。
直到不出错,然后给x_c,x_p,x_t赋值。注意这里的i已经是整个数据集中的下标了,不是所需要数据的下标。
拿x_c来说,啊这里我看到代码中大佬的注释明白了!我第一次读这个注释的时候还云里雾里的,看来多看几遍确实有助于理解。说正事:
刚刚我不明白的len_closeness现在明白了,就是x_c的长度是3,比如说,其中一个x_c是[Timestamp‘2024-09-25 09:00:00’,Timestamp‘2024-09-25 08:30:00’,Timestamp‘2024-09-25 08:00:00’],包括三个时间戳,而且前面时间离现在时间近
x_c这个列表中有i之前的三个相邻的时间戳。同理
x_p这个列表中有三个i之前间隔一天的时间戳
x_t这个列表中有三个i之前间隔为一周的时间戳
y就是i处实际的时间戳
np.vstack就是把列表(也可以看作是一维向量)顺着行排起来。举个例子:
a = np.array([1,2,3])
b = np.array([7,8,9])
print(np.vstack(a,b))结果是:
[[1 2 3]
[7 8 9]]
(相对地,np.hstack就是水平顺着列排起来)
最后i+1是想把所有能够获得的而且符合条件的时间都放进去吗?
最后得到了XC、XP、XT,二维列表。Y、timestamps_Y是一维列表。
这里timestamps_Y与Y的区别?
timestamps_Y的元素是timestamps,是这个格式[b'2024092526', b'2024092527']
Y的元素是pd_timestamps,是这个格式[Timestamp('2024-09-25 12:30:00'), Timestamp('2024-09-25 13:00:00')]
8.if __name__ == '__main__':
作用跟上文提到的一样
9.STMatrix.py源代码
# Acknowledgement: This code is taken from https://github.com/TolicWang/DeepST
import numpy as np
import pandas as pd
from .timestamp import string2timestamp
class STMatrix(object):
"""docstring for STMatrix"""
def __init__(self, data, timestamps, T=48, CheckComplete=True):
super(STMatrix, self).__init__()
assert len(data) == len(timestamps)
self.data = data
self.timestamps = timestamps# [b'2013070101', b'2013070102']
self.T = T
self.pd_timestamps = string2timestamp(timestamps, T=self.T)
if CheckComplete:
self.check_complete()
# index
self.make_index() # 将时间戳:做成一个字典,也就是给每个时间戳一个序号
def make_index(self):
self.get_index = dict()
for i, ts in enumerate(self.pd_timestamps):
self.get_index[ts] = i
def check_complete(self):
missing_timestamps = []
offset = pd.DateOffset(minutes=24 * 60 // self.T)
pd_timestamps = self.pd_timestamps
i = 1
while i < len(pd_timestamps):
if pd_timestamps[i - 1] + offset != pd_timestamps[i]:
missing_timestamps.append("(%s -- %s)" % (pd_timestamps[i - 1], pd_timestamps[i]))
i += 1
for v in missing_timestamps:
print(v)
assert len(missing_timestamps) == 0
def get_matrix(self, timestamp): # 给定时间戳返回对于的数据
return self.data[self.get_index[timestamp]]
def save(self, fname):
pass
def check_it(self, depends):
for d in depends:
if d not in self.get_index.keys():
return False
return True
def create_dataset(self, len_closeness=3, len_trend=3, TrendInterval=7, len_period=3, PeriodInterval=1):
"""current version
"""
# offset_week = pd.DateOffset(days=7)
offset_frame = pd.DateOffset(minutes=24 * 60 // self.T) # 时间偏移 minutes = 30
XC = []
XP = []
XT = []
Y = []
timestamps_Y = []
depends = [range(1, len_closeness + 1),
[PeriodInterval * self.T * j for j in range(1, len_period + 1)],
[TrendInterval * self.T * j for j in range(1, len_trend + 1)]]
# print depends # [range(1, 4), [48, 96, 144], [336, 672, 1008]]
i = max(self.T * TrendInterval * len_trend, self.T * PeriodInterval * len_period, len_closeness)
while i < len(self.pd_timestamps):
Flag = True
for depend in depends:
if Flag is False:
break
Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend])
if Flag is False:
i += 1
continue
x_c = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[0]]
# 取当前时刻的前3个时间片的数据数据构成“邻近性”模块中一个输入序列
# 例如当前时刻为[Timestamp('2013-07-01 00:00:00')]
# 则取:
# [Timestamp('2013-06-30 23:30:00'), Timestamp('2013-06-30 23:00:00'), Timestamp('2013-06-30 22:30:00')]
# 三个时刻所对应的in-out flow为一个序列
x_p = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[1]]
# 取当前时刻 前 1*PeriodInterval,2*PeriodInterval,...,len_period*PeriodInterval
# 天对应时刻的in-out flow 作为一个序列,例如按默认值为 取前1、2、3天同一时刻的In-out flow
x_t = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[2]]
# 取当前时刻 前 1*TrendInterval,2*TrendInterval,...,len_trend*TrendInterval
# 天对应时刻的in-out flow 作为一个序列,例如按默认值为 取 前7、14、21天同一时刻的In-out flow
y = self.get_matrix(self.pd_timestamps[i])
if len_closeness > 0:
XC.append(np.vstack(x_c))
# a.shape=[2,32,32] b.shape=[2,32,32] c=np.vstack((a,b)) -->c.shape = [4,32,32]
if len_period > 0:
XP.append(np.vstack(x_p))
if len_trend > 0:
XT.append(np.vstack(x_t))
Y.append(y)
timestamps_Y.append(self.timestamps[i])#[]
i += 1
XC = np.asarray(XC) # 模拟 邻近性的 数据 [?,6,32,32]
XP = np.asarray(XP) # 模拟 周期性的 数据 隔天
XT = np.asarray(XT) # 模拟 趋势性的 数据 隔周
Y = np.asarray(Y)# [?,2,32,32]
print("XC shape: ", XC.shape, "XP shape: ", XP.shape, "XT shape: ", XT.shape, "Y shape:", Y.shape)
return XC, XP, XT, Y, timestamps_Y
if __name__ == '__main__':
# depends = [range(1, 3 + 1),
# [1 * 48 * j for j in range(1, 3 + 1)],
# [7 * 48 * j for j in range(1, 3 + 1)]]
# print(depends)
# print([j for j in depends[0]])
str = ['2013070101']
t = string2timestamp(str)
offset_frame = pd.DateOffset(minutes=24 * 60 // 48) # 时间偏移 minutes = 30
print(t)
o = [t[0] - j * offset_frame for j in range(1, 4)]
print(o)
四、后续
后续会继续更新TaxiBj.py和STResNet.py代码的详细解析,希望能够与大家分享讨论。
注:本人是刚接触论文的大二生,全文的代码分析都是自己的理解,难免会有错误,欢迎指出!我会努力改正的!望各位大佬指点!