论文辅助笔记:Tempo 之 model.py

 0 导入库

import math
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn

from src.modules.transformer import Block
from src.modules.prompt import Prompt
from src.modules.utils import (
    FlattenHead,
    PoolingHead,
    RevIN,
)


1TEMPOConfig

1.1 构造函数

class TEMPOConfig:
    """
    Configuration of a `TEMPO` model.

    Args:
        num_series: 时间序列的数量, N 
        input_len: 输入时间序列的长度, L
        pred_len: 预测时间序列的长度, Y
        block_size: 块的最大长度(openai gpt2 固定)
        n_layer: Transformer 层的数量
        n_head: 多头注意力机制中的头数量
        n_embd: 嵌入维度的数量
        patch_size: 块的大小,用于将输入时间序列分割成多个小块
        patch_stride: 块的步幅,用于指定块之间的重叠程度
        revin: 是否使用 RevIN(归一化和逆变换)
        affine: 在 RevIN 中是否使用仿射变换
        embd_pdrop:嵌入层的 dropout 率
        resid_pdrop: 残差连接的 dropout 率
        attn_pdrop: 注意力层的 dropout 率
        head_type: 输出层的类型,可以是 FlattenHead 或 PoolingHead
        head_pdtop: 输出层的 dropout 率
        individual: 是否为每个组件使用独立的输出层
        lora: 是否使用 LoRA(低秩近似)
        lora_config: LoRA 的配置
        model_type: 模型类型,默认为 gpt2
        interpret: 是否输出组件以便解释
    """

    num_series: int
    input_len: int
    pred_len: int
    patch_size: int
    patch_stride: int
    block_size: int = None
    n_layer: int = None
    n_head: int = None
    n_embd: int = None
    revin: bool = True
    affine: bool = True
    embd_pdrop: float = 0.1
    resid_pdrop: float = 0.1
    attn_pdrop: float = 0.1
    head_type: str = "flatten"
    head_pdtop: float = 0.1
    individual: bool = False
    lora: bool = False
    lora_config: dict = None
    prompt_config: dict = None
    #Prompt 模块的配置
    model_type: str = "gpt2"
    interpret: bool = False

1.2  todict

TEMPOConfig 类实例转换为一个字典

def todict(self):
    return asdict(self)

'''
asdict 是 Python 的 dataclasses 模块提供的一个函数,用于将数据类实例转换为字典。

这个方法将当前实例的所有属性转换为字典键值对,并返回这个字典。
'''

1.3 __contains__

重载了 Python 的 __contains__ 魔术方法,使得 TEMPOConfig 实例可以像字典一样使用 in 操作符来检查属性是否存在。

def __contains__(self, key):
    return key in self.todict()

1.4 __getitem__

重载了 __getitem__ 魔术方法,使得 TEMPOConfig 实例可以像字典一样通过键来获取属性值

def __getitem__(self, key):
    return getattr(self, key)

1.5__setitem__

重载了 __setitem__ 魔术方法,使得 TEMPOConfig 实例可以像字典一样通过键来设置属性值

def __setitem__(self, key, value):
    setattr(self, key, value)

1.6 update

通过一个字典 config 更新 TEMPOConfig 实例的属性

def update(self, config: dict):
    for k, v in config.items():
        setattr(self, k, v)

2 TEMPO

class TEMPO(nn.Module):
    """
    Notation:
        B: 批次大小
        N: 时间序列的数量
        E: 嵌入维度
        P: 块的数量
        PS: patch的大小
        L: 输入时间序列的长度
        Y: 预测时间序列的长度
    """

    models = ("gpt2",)
    #支持的模型类型列表

    head_types = ("flatten", "pooling")
    #支持的输出层类型

    params = {
        "gpt2": dict(block_size=1024, n_head=12, n_embd=768),
    }
    '''
    模型的参数,例如 "gpt2" 模型的块大小、注意力头数和嵌入维度等
    '''

2.1 __init__

  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值