论文辅助笔记:TEMPO 之 utils.py

0 导入库

from typing import Tuple
import random
import numpy as np
import torch
from statsmodels.tsa.seasonal import STL

1 EarlyStopping

  • 提供了一个早停机制,用于在模型训练过程中监控验证集上的损失
  • 如果损失停止改进,则停止训练

1.1 __init__

class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        #早停的容忍度,如果连续 patience 次验证损失没有改善,则停止训练。

        self.verbose = verbose
        #决定是否输出详细信息


        self.counter = 0
        #记录连续未改善验证损失的次数


        self.best_score = None
        #用于存储目前为止最佳的验证损失分数

        self.early_stop = False
        #一个布尔值,指示是否应该停止训练


        self.val_loss_min = np.Inf
        #存储目前为止最小的验证损失

        self.delta = delta
        #一个阈值,用于决定损失的改善幅度

1.2 __call__ 在训练过程中监控验证损失

def __call__(self, val_loss, model, path):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            #如果这是第一次调用 __call__,初始化 best_score 为 score 并保存模型。
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
            '''
            如果 score < self.best_score + self.delta,则说明损失没有显著改善
            
            增加 counter 并检查是否超过 patience,如果超过则停止训练
            '''
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0
            '''
            如果 score > self.best_score + self.delta,更新 best_score 并保存模型
            然后将 counter 重置为零
            '''

1.3 save_checkpoint 在验证损失降低时保存模型

def save_checkpoint(self, val_loss, model, path):
        if self.verbose:
            print(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ..."
            )
        torch.save(model.state_dict(), path + "/" + "checkpoint.pth")
        #使用 torch.save() 保存模型的状态字典
        self.val_loss_min = val_loss
        

2 StandardScaler

实现数据标准化

2.1 __init__

class StandardScaler:
    def __init__(self):
        self.mean = 0.0
        self.std = 1.0

2.2  fit

计算并更新 self.meanself.std

def fit(self, data):
        self.mean = data.mean(0)
        self.std = data.std(0)

 2.3  transform

   将数据转换为标准化形式

def transform(self, data):
        mean = (
            torch.from_numpy(self.mean).type_as(data).to(data.device)
            if torch.is_tensor(data)
            else self.mean
        )
        std = (
            torch.from_numpy(self.std).type_as(data).to(data.device)
            if torch.is_tensor(data)
            else self.std
        )
        '''
        mean 和 std 的类型转换:
            根据 data 是 torch.Tensor 还是 numpy 数组
            将 self.mean 和 self.std 转换为相应类型,以确保类型匹配
        '''
        return (data - mean) / std

 2.4 inverse_transform

将标准化后的数据还原

    def inverse_transform(self, data):
        mean = (
            torch.from_numpy(self.mean).type_as(data).to(data.device)
            if torch.is_tensor(data)
            else self.mean
        )
        std = (
            torch.from_numpy(self.std).type_as(data).to(data.device)
            if torch.is_tensor(data)
            else self.std
        )

        '''
        mean 和 std 的类型转换:
            根据 data 是 torch.Tensor 还是 numpy 数组
            将 self.mean 和 self.std 转换为相应类型,以确保类型匹配
        '''

        if data.shape[-1] != mean.shape[-1]:
            mean = mean[-1:]
            std = std[-1:]


        return (data * std) + mean
        '''
        通过 (data * std) + mean 将标准化后的数据还原为原始形式
        '''

3 decompose

使用STL,将时间序列分解为趋势、季节性和残差成分

def decompose(
    x: torch.Tensor, period: int = 7
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    
    #x:输入的一维时间序列,类型为 torch.Tensor,形状为 (1, seq_len)
    x = x.squeeze(0).cpu().numpy()
    '''
    首先调用 squeeze(0) 将 x 的第一个维度去掉
    然后通过 cpu().numpy() 将 x 转换为 numpy 数组,以便 STL 分解函数使用
    '''


    decomposed = STL(x, period=period).fit()
    '''
    调用 STL(x, period=period).fit() 对 x 进行分解,并返回分解结果 decomposed
    
    其中包含了 trend(趋势)、seasonal(季节性)和 resid(残差)成分
    '''


    trend = decomposed.trend.astype(np.float32)
    seasonal = decomposed.seasonal.astype(np.float32)
    residual = decomposed.resid.astype(np.float32)
    '''
    将 decomposed 中的各个成分转换为 numpy 数组,并转为 float32 类型
    '''


    return (
        torch.from_numpy(trend).unsqueeze(0),
        torch.from_numpy(seasonal).unsqueeze(0),
        torch.from_numpy(residual).unsqueeze(0),
    )
    '''
    将它们转换为 torch.Tensor
    并使用 unsqueeze(0) 将其包装为 (1, seq_len) 的张量,以匹配输入张量的形状
    '''

4 set_seed

为 Python 中的各种随机生成器设置种子

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值