DeepST/deepst/datasets/STMatrix.py 代码解析

from __future__ import print_function
import os
import pandas as pd
import numpy as np

from . import load_stdata
from ..config import Config
from ..utils import string2timestamp
//from . import,“.”  代表使用相对路径导入,即从当前项目中寻找需要导入的包或数,from..import绝对导入语句。一个"."表示往上跳一级,假如A包含B和C,要往B里import一个东西,可以写from ..A(两个".",跳的比A高一级了,可) import C.

class STMatrix(object):
    """docstring for STMatrix"""//STMatrix的字符串文本

    def __init__(self, data, timestamps, T=48, CheckComplete=True)://定义构造函数
        super(STMatrix, self).__init__()//#super表继承,这里继承自己
        assert len(data) == len(timestamps)//assert:断言 前置条件断言:代码执行之前必须具备的特性,如果不满足程序就会中断
        self.data = data
        self.timestamps = timestamps
        self.T = T
        self.pd_timestamps = string2timestamp(timestamps, T=self.T)//字符转换成时间戳,timestamp = time.time(),为float型,时间戳是计算机能够识别的时间;时间字符串是人能够看懂的时间;元组则是用来操作时间的。
        if CheckComplete:
            self.check_complete()//转成每半小时每半小时
        # index
        self.make_index()//以每半小时的时间做索引

    def make_index(self)://可能是做一个索引
        self.get_index = dict()//字典 包括索引和对应的值
        for i, ts in enumerate(self.pd_timestamps)://enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
            self.get_index[ts] = i

    def check_complete(self):
        missing_timestamps = []
        offset = pd.DateOffset(minutes=24 * 60 // self.T)  //T=48,"//"表示取整除 - 返回商的整数部分(向下取整)。DateOffset可按指定的日历日时间段偏移日期时间,可能是把时间归成每半个小时一个。
        pd_timestamps = self.pd_timestamps
        i = 1
        while i < len(pd_timestamps):
            if pd_timestamps[i-1] + offset != pd_timestamps[i]:
                missing_timestamps.append("(%s -- %s)" % (pd_timestamps[i-1], pd_timestamps[i]))
            i += 1
        for v in missing_timestamps:
            print(v)
        assert len(missing_timestamps) == 0

    def get_matrix(self, timestamp)://获得timestamp为索引的data
        return self.data[self.get_index[timestamp]]

    def save(self, fname):
        pass     //Python pass 是空语句,是为了保持程序结构的完整性。pass 不做任何事情,一般用做占位语句。该处的 pass 便是占据一个位置,因为如果定义一个空函数程序会报错,当你没有想好函数的内容是可以用 pass 填充,使程序可以正常运行。

    def check_it(self, depends):   //检查关键字,应该要是对的
        for d in depends:
            if d not in self.get_index.keys():
                return False
        return True

    def create_dataset(self, len_closeness=3, len_trend=3, TrendInterval=7, len_period=3, PeriodInterval=1):    //这个函数没细看(太难了),大概讲的是构造出临近性、趋势性、周期性的三种数据形式
        """current version
        """
        # offset_week = pd.DateOffset(days=7)
        offset_frame = pd.DateOffset(minutes=24 * 60 // self.T)
        XC = []
        XP = []
        XT = []
        Y = []
        timestamps_Y = []
        depends = [range(1, len_closeness+1),
                   [PeriodInterval * self.T * j for j in range(1, len_period+1)],
                   [TrendInterval * self.T * j for j in range(1, len_trend+1)]]

        i = max(self.T * TrendInterval * len_trend, self.T * PeriodInterval * len_period, len_closeness)
        while i < len(self.pd_timestamps):
            Flag = True
            for depend in depends:
                if Flag is False:
                    break
                Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend])

            if Flag is False:
                i += 1
                continue
            x_c = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[0]]
            x_p = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[1]]
            x_t = [self.get_matrix(self.pd_timestamps[i] - j * offset_frame) for j in depends[2]]
            y = self.get_matrix(self.pd_timestamps[i])
            if len_closeness > 0:
                XC.append(np.vstack(x_c))
            if len_period > 0:
                XP.append(np.vstack(x_p))
            if len_trend > 0:
                XT.append(np.vstack(x_t))
            Y.append(y)
            timestamps_Y.append(self.timestamps[i])
            i += 1
        XC = np.asarray(XC)
        XP = np.asarray(XP)
        XT = np.asarray(XT)
        Y = np.asarray(Y)
        print("XC shape: ", XC.shape, "XP shape: ", XP.shape, "XT shape: ", XT.shape, "Y shape:", Y.shape)
        return XC, XP, XT, Y, timestamps_Y


if __name__ == '__main__':
    pass    //当**.py**文件被直接运行时,if __name__ ==’__main__'之下的代码块将被运行;当.py文件以模块形式被导入时,if __name__ == '__main__'之下的代码块不被运行。

总代码链接https://github.com/amirkhango/DeepST

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值