Lucidrains 系列项目源码解析(八十二)

.\lucidrains\reformer-pytorch\reformer_pytorch\autopadder.py

# 导入数学库和 PyTorch 库
import math
import torch
from torch import nn
import torch.nn.functional as F

# 导入自定义模块
from reformer_pytorch.reformer_pytorch import Reformer, ReformerLM, LSHSelfAttention

# 定义函数,用于将张量填充到指定的倍数
def pad_to_multiple(tensor, seqlen, multiple, dim=-1):
    # 计算倍数
    m = seqlen / multiple
    # 如果是整数倍则直接返回张量
    if m.is_integer():
        return tensor
    # 计算需要填充的长度
    remainder = math.ceil(m) * multiple - seqlen
    # 计算填充的偏移量
    pad_offset = (0,) * (-1 - dim) * 2
    # 对张量进行填充
    return F.pad(tensor, (*pad_offset, 0, remainder), value=0)

# 定义自动填充器类
class Autopadder(nn.Module):
    def __init__(self, net):
        super().__init__()
        # 检查输入的网络类型是否符合要求
        assert isinstance(net, (LSHSelfAttention, Reformer, ReformerLM)), 'only modules LSHSelfAttention, Reformer, ReformerLM accepted'
        self.net = net

        # 获取 Reformer 对象
        reformer = net.reformer if isinstance(net, ReformerLM) else net
        # 根据网络类型确定填充的维度
        self.pad_dim = -1 if isinstance(net, ReformerLM) else -2

        # 获取 Reformer 的参数
        self.bucket_size = reformer.bucket_size
        self.num_mem_kv = reformer.num_mem_kv
        self.full_attn_thres = reformer.full_attn_thres

    def forward(self, x, **kwargs):
        # 获取输入张量的形状信息
        b, t, m, device = *x.shape[:2], self.num_mem_kv, x.device

        # 获取关键信息和输入掩码
        keys = kwargs.get('keys')
        input_mask = kwargs.get('input_mask')
        input_attn_mask = kwargs.get('input_attn_mask')

        # 计算关键信息的长度
        k_len = 0 if keys is None else keys.shape[1]
        # 计算序列长度
        seqlen = t + m + k_len

        # 如果序列长度超过全局注意力阈值
        if seqlen > self.full_attn_thres:
            # 如果输入掩码为空,则创建全为 True 的掩码
            if input_mask is None:
                input_mask = torch.full((b, t), True, device=x.device, dtype=torch.bool)

            # 对输入张量进行填充
            x = pad_to_multiple(x, seqlen, self.bucket_size * 2, dim=self.pad_dim)

            # 如果输入掩码不为空,则对其进行填充
            if input_mask is not None:
                new_mask = F.pad(input_mask, (0, x.shape[1] - input_mask.shape[1]), value=False)
                kwargs.update(input_mask=new_mask)

            # 如果输入注意力掩码不为空,则对其进行填充
            if input_attn_mask is not None:
                offset = x.shape[1] - input_attn_mask.shape[1]
                new_mask = F.pad(input_attn_mask, (0, offset, 0, offset), value=False)
                kwargs.update(input_attn_mask=new_mask)

        # 对输入进行网络前向传播
        out = self.net(x, **kwargs)
        # 返回前 t 个时间步的输出
        return out[:, 0:t]

.\lucidrains\reformer-pytorch\reformer_pytorch\generative_tools.py

# 导入必要的库
from functools import partial
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from reformer_pytorch.reformer_pytorch import ReformerLM
from reformer_pytorch.autopadder import Autopadder

# 定义函数用于根据概率阈值选择最高概率的元素
def top_p(logits, thres = 0.9):
    # 对logits进行降序排序
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    # 计算累积概率
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # 根据阈值确定要移除的元素
    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    # 将超过阈值的元素设置为负无穷
    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# 定义函数用于根据概率阈值选择前k个元素
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# 定义一个包装类,用于训练模型
class TrainingWrapper(nn.Module):
    def __init__(self, net, ignore_index = -100, pad_value = 0):
        super().__init__()
        assert isinstance(net, ReformerLM), 'generative trainer wrapper can only accept ReformerLM class'
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = Autopadder(net)
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        input_mask = kwargs.pop('input_mask', None)

        if input_mask is None:
            input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]
            input_mask = input_mask[:, -self.max_seq_len:]

            logits = self.net(x, input_mask=input_mask, **kwargs)[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            input_mask = F.pad(input_mask, (0, 1), value=True)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out

    def forward(self, x, return_loss = False, **kwargs):
        pad = partial(pad_sequence, batch_first = True, padding_value = self.pad_value)

        if not return_loss:
            if not isinstance(x, torch.Tensor):
                x = pad(x)
            return self.net(x, **kwargs)

        if isinstance(x, torch.Tensor):
            xi = x[:, :-1]
            xo = x[:, 1:]
        else:
            xi = pad(list(map(lambda t: t[:-1], x)))
            xo = pad(list(map(lambda t: t[1:], x)))

        out = self.net(xi, **kwargs)

        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
        return loss

.\lucidrains\reformer-pytorch\reformer_pytorch\recorder.py

# 导入需要的模块
from torch import nn
from reformer_pytorch.reformer_pytorch import LSHAttention, LSHSelfAttention
from collections import defaultdict

# 定义 Recorder 类,继承自 nn.Module
class Recorder(nn.Module):
    # 初始化函数
    def __init__(self, net):
        super().__init__()
        self.iter = 0
        self.recordings = defaultdict(list)  # 使用 defaultdict 创建一个空列表的字典
        self.net = net
        self.on = True
        self.ejected = False

    # 弹出函数
    def eject(self):
        self.ejected = True
        self.clear()
        self.unwire()
        return self.net

    # 连接函数
    def wire(self):
        # 遍历网络中的模块,如果是 LSHAttention 类型,则设置 _return_attn 为 True
        for module in self.net.modules():
            if isinstance(module, LSHAttention):
                module._return_attn = True
            # 如果是 LSHSelfAttention 类型,则设置 callback 为 self.record 函数
            if isinstance(module, LSHSelfAttention):
                module.callback = self.record

    # 断开连接函数
    def unwire(self):
        # 遍历网络中的模块,如果是 LSHAttention 类型,则设置 _return_attn 为 False
        for module in self.net.modules():
            if isinstance(module, LSHAttention):
                module._return_attn = False
            # 如果是 LSHSelfAttention 类型,则设置 callback 为 None
            if isinstance(module, LSHSelfAttention):
                module.callback = None

    # 打开记录功能
    def turn_on(self):
        self.on = True

    # 关闭记录功能
    def turn_off(self):
        self.on = False

    # 清空记录
    def clear(self):
        del self.recordings
        self.recordings = defaultdict(list)  # 使用 defaultdict 创建一个空列表的字典
        self.iter = 0

    # 记录函数
    def record(self, attn, buckets):
        if not self.on: return
        data = {'attn': attn.detach().cpu(), 'buckets': buckets.detach().cpu()}
        self.recordings[self.iter].append(data)

    # 前向传播函数
    def forward(self, x, **kwargs):
        assert not self.ejected, 'Recorder has already been ejected and disposed'
        if self.on:
            self.wire()

        out = self.net(x, **kwargs)

        self.iter += 1
        self.unwire()
        return out

.\lucidrains\reformer-pytorch\reformer_pytorch\reformer_enc_dec.py

# 导入 re 模块,用于正则表达式操作
import re
# 从 torch 模块中导入 nn 类
from torch import nn
# 从 reformer_pytorch 模块中导入 ReformerLM 类
from reformer_pytorch.reformer_pytorch import ReformerLM
# 从 reformer_pytorch 模块中导入 TrainingWrapper 类
from reformer_pytorch.generative_tools import TrainingWrapper

# 定义编码器前缀
ENC_PREFIX = 'enc_'
# 定义解码器前缀
DEC_PREFIX = 'dec_'

# 根据条件将字典按键分组
def group_dict_by_key(cond, d):
    return_val = [dict(),dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)

# 判断字符串是否以指定前缀开头
def string_begins_with(prefix, str):
    return bool(re.match(f'^{prefix}', str))

# 根据键前缀将字典分组
def group_by_key_prefix(prefix, d):
    return group_dict_by_key(lambda x: string_begins_with(prefix, x), d)

# 根据键前缀并移除前缀将字典分组
def group_by_key_prefix_and_remove_prefix(prefix, d):
    kwargs_with_prefix, kwargs = group_dict_by_key(lambda x: string_begins_with(prefix, x), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))
    return kwargs_without_prefix, kwargs

# 提取编码器和解码器的关键字参数
def extract_enc_dec_kwargs(kwargs):
    enc_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(ENC_PREFIX, kwargs)
    dec_kwargs, kwargs = group_by_key_prefix_and_remove_prefix(DEC_PREFIX, kwargs)
    return enc_kwargs, dec_kwargs, kwargs

# 提取并设置编码器和解码器的关键字参数
def extract_and_set_enc_dec_kwargs(kwargs):
    enc_kwargs, dec_kwargs, kwargs = extract_enc_dec_kwargs(kwargs)
    if 'input_mask' in enc_kwargs:
        dec_kwargs.setdefault('context_mask', enc_kwargs['input_mask'])
    return enc_kwargs, dec_kwargs, kwargs

# 定义 ReformerEncDec 类,继承自 nn.Module 类
class ReformerEncDec(nn.Module):
    def __init__(self, dim, ignore_index = 0, pad_value = 0, **kwargs):
        super().__init__()
        enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)
        
        # 断言不能手动设置返回嵌入标志
        assert 'return_embedding' not in enc_kwargs, 'you cannot manually set the return embeddings flag for the encoder'
        # 断言必须为编码器和解码器设置维度
        assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

        # 设置编码器和解码器的维度
        enc_kwargs['dim'] = dec_kwargs['dim'] = dim
        enc_kwargs['return_embeddings'] = True
        dec_kwargs['causal'] = True

        # 设置编码器和解码器的 bucket_size
        enc_kwargs.setdefault('bucket_size', 64)
        dec_kwargs.setdefault('bucket_size', enc_kwargs['bucket_size'] * 2)

        # 创建 ReformerLM 编码器和解码器对象
        enc = ReformerLM(**enc_kwargs)
        dec = ReformerLM(**dec_kwargs)

        # 使用 TrainingWrapper 封装编码器和解码器对象
        self.enc = TrainingWrapper(enc, ignore_index = ignore_index, pad_value = pad_value)
        self.dec = TrainingWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)

    # 生成序列
    def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        enc_keys = self.enc(seq_in, **enc_kwargs)
        return self.dec.generate(seq_out_start, seq_len, keys = enc_keys, **{**dec_kwargs, **kwargs})

    # 前向传播
    def forward(self, seq_in, seq_out, return_loss = False, **kwargs):
        enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
        enc_keys = self.enc(seq_in, **enc_kwargs)
        return self.dec(seq_out, return_loss = return_loss, keys = enc_keys, **dec_kwargs)

.\lucidrains\reformer-pytorch\reformer_pytorch\reformer_pytorch.py

# 导入数学库
import math
# 导入 PyTorch 库
import torch
import torch.nn as nn
# 从 torch.nn 模块导入 Identity 类
from torch.nn import Identity
# 导入 torch.nn.functional 模块
import torch.nn.functional as F
# 从 torch.autograd 模块导入 Function 类
from torch.autograd import Function
# 从 functools 模块导入 partial、reduce、wraps 函数
from functools import partial, reduce, wraps
# 从 itertools 模块导入 chain 函数
from itertools import chain
# 从 operator 模块导入 mul 函数
from operator import mul

# 导入自定义模块
from local_attention import LocalAttention
from axial_positional_embedding import AxialPositionalEmbedding
from product_key_memory import PKM
from reformer_pytorch.reversible import ReversibleSequence

# 导入 einops 库
from einops import rearrange, repeat

# 常量定义

# 用于自注意力机制的特殊值,用于半精度计算
TOKEN_SELF_ATTN_VALUE = -5e4

# 辅助函数

# 判断变量是否存在
def exists(val):
    return val is not None

# 对两个张量进行排序,并返回排序后的值和对应的张量
def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)

# 在指定维度上对张量进行批量索引选择
def batched_index_select(values, indices):
    last_dim = values.shape[-1]
    return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))

# 对输入进行分块处理
def process_inputs_chunk(fn, chunks=1, dim=0):
    def inner_fn(*args, **kwargs):
        keys, values, len_args = kwargs.keys(), kwargs.values(), len(args)
        chunked_args = list(zip(*map(lambda x: x.chunk(chunks, dim=dim), list(args) + list(values))))
        all_args = map(lambda x: (x[:len_args], dict(zip(keys, x[len_args:]))), chunked_args)
        outputs = [fn(*c_args, **c_kwargs) for c_args, c_kwargs in all_args]
        return tuple(map(lambda x: torch.cat(x, dim=dim), zip(*outputs)))
    return inner_fn

# 对张量进行分块求和
def chunked_sum(tensor, chunks=1):
    *orig_size, last_dim = tensor.shape
    tensor = tensor.reshape(-1, last_dim)
    summed_tensors = [c.sum(dim=-1) for c in tensor.chunk(chunks, dim=0)]
    return torch.cat(summed_tensors, dim=0).reshape(orig_size)

# 返回默认值
def default(val, default_val):
    return default_val if val is None else val

# 将输入转换为元组
def cast_tuple(x):
    return x if isinstance(x, tuple) else (x,)

# 返回张量的最大负值
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

# 缓存函��的计算结果
def cache_fn(f):
    cache = None
    @wraps(f)
    def cached_fn(*args, **kwargs):
        nonlocal cache
        if cache is not None:
            return cache
        cache = f(*args, **kwargs)
        return cache
    return cached_fn

# 缓存方法的计算结果
def cache_method_decorator(cache_attr, cache_namespace, reexecute=False):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, key_namespace=None, fetch=False, set_cache=True, **kwargs):
            namespace_str = str(default(key_namespace, ''))
            _cache = getattr(self, cache_attr)
            _keyname = f'{cache_namespace}:{namespace_str}'

            if fetch:
                val = _cache[_keyname]
                if reexecute:
                    fn(self, *args, **kwargs)
            else:
                val = fn(self, *args, **kwargs)
                if set_cache:
                    setattr(self, cache_attr, {**_cache, **{_keyname: val}})
            return val
        return wrapper
    return inner_fn

# 在指定维度上扩展张量的维度
def expand_dim(dim, k, t):
    t = t.unsqueeze(dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

# 合并张量的维度
def merge_dims(ind_from, ind_to, tensor):
    shape = list(tensor.shape)
    arr_slice = slice(ind_from, ind_to + 1)
    shape[arr_slice] = [reduce(mul, shape[arr_slice])]
    return tensor.reshape(*shape)

# 在指定维度上将张量拆分为两部分
def split_at_index(dim, index, t):
    pre_slices = (slice(None),) * dim
    l = (*pre_slices, slice(None, index))
    r = (*pre_slices, slice(index, None))
    return t[l], t[r]

# 辅助类

# 始终返回固定值的模块
class Always(nn.Module):
    def __init__(self, val):
        super().__init__()
        self.val = val

    def forward(self, *args, **kwargs):
        return self.val

# 矩阵乘法模块
class MatrixMultiply(nn.Module):
    def __init__(self, tensor, transpose=False, normalize=False):
        super().__init__()
        self.tensor = tensor
        self.transpose = transpose
        self.normalize = normalize
    # 定义一个前向传播函数,接受输入张量 x
    def forward(self, x):
        # 将类中的张量赋值给变量 tensor
        tensor = self.tensor
        # 如果需要进行标准化操作
        if self.normalize:
            # 对张量进行标准化操作,沿着最后一个维度进行标准化
            tensor = F.normalize(tensor, dim=-1)
        # 如果需要进行转置操作
        if self.transpose:
            # 对张量进行转置操作
            tensor = tensor.t()
        # 返回输入张量与处理后的张量的矩阵乘法结果
        return x @ tensor
# 定义 ReZero 类,继承自 nn.Module
class ReZero(nn.Module):
    # 初始化函数,接受一个函数 fn 作为参数
    def __init__(self, fn):
        super().__init__()
        # 创建一个可学习的参数 g,初始化为零
        self.g = nn.Parameter(torch.zeros(1))
        # 将传入的函数 fn 赋值给 self.fn
        self.fn = fn

    # 前向传播函数,接受输入 x 和其他关键字参数
    def forward(self, x, **kwargs):
        # 返回经过函数 fn 处理后的结果乘以参数 g
        return self.fn(x, **kwargs) * self.g

# 定义 ScaleNorm 类,继承自 nn.Module
class ScaleNorm(nn.Module):
    # 初始化函数,接受维度 dim 和一个小数 eps 作为参数
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        # 创建一个可学习的参数 g,初始化为一
        self.g = nn.Parameter(torch.ones(1))
        # 将传入的 eps 赋值给 self.eps
        self.eps = eps

    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 计算 x 在指定维度上的范数,并限制最小值为 eps
        n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
        # 返回 x 除以范数后乘以参数 g 的结果
        return x / n * self.g

# 定义 PreNorm 类,继承自 nn.Module
class PreNorm(nn.Module):
    # 初始化函数,接受一个规范化类 norm_class、维度 dim 和一个函数 fn 作为参数
    def __init__(self, norm_class, dim, fn):
        super().__init__()
        # 创建一个 norm_class 类型的规范化对象,并赋值给 self.norm
        self.norm = norm_class(dim)
        # 将传入的函数 fn 赋值给 self.fn
        self.fn = fn

    # 前向传播函数,接受输入 x 和其他关键字参数
    def forward(self, x, **kwargs):
        # 对输入 x 进行规范化
        x = self.norm(x)
        # 返回经过函数 fn 处理后的结果
        return self.fn(x, **kwargs)

# 定义 Chunk 类,继承自 nn.Module
class Chunk(nn.Module):
    # 初始化函数,接受块数 chunks、函数 fn 和沿着的维度 along_dim 作为参数
    def __init__(self, chunks, fn, along_dim=-1):
        super().__init__()
        # 将 along_dim 赋值给 self.dim
        self.dim = along_dim
        # 将 chunks 和 fn 赋值给 self.chunks 和 self.fn
        self.chunks = chunks
        self.fn = fn

    # 前向传播函数,接受输入 x 和其他关键字参数
    def forward(self, x, **kwargs):
        # 如果 chunks 等于 1,则直接返回经过函数 fn 处理后的结果
        if self.chunks == 1:
            return self.fn(x, **kwargs)
        # 将输入 x 沿着维度 self.dim 切分成多个块
        chunks = x.chunk(self.chunks, dim=self.dim)
        # 对每个块应用函数 fn,并在指定维度上拼接结果
        return torch.cat([self.fn(c, **kwargs) for c in chunks], dim=self.dim)

# LSH attention 类,实现了论文中描述的 LSH 注意力机制
class LSHAttention(nn.Module):
    # 初始化函数,接受多个参数设置
    def __init__( self,
                  dropout=0.,
                  bucket_size=64,
                  n_hashes=8,
                  causal=False,
                  allow_duplicate_attention=True,
                  attend_across_buckets=True,
                  rehash_each_round=True,
                  drop_for_hash_rate=0.0,
                  random_rotations_per_head=False,
                  return_attn=False):
        super().__init__()
        # 如果 dropout 大于等于 1,则抛出异常
        if dropout >= 1.0:
            raise ValueError('Dropout rates must be lower than 1.')

        # 创建一个 dropout 层,用于在训练时随机丢弃部分数据
        self.dropout = nn.Dropout(dropout)
        self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)

        # 确保每轮重新哈希或允许重复注意力的设置
        assert rehash_each_round or allow_duplicate_attention, (
            'The setting {allow_duplicate_attention=False, rehash_each_round=False}'
            ' is not implemented.')

        # 设置是否是因果关系
        self.causal = causal
        self.bucket_size = bucket_size

        self.n_hashes = n_hashes

        self._allow_duplicate_attention = allow_duplicate_attention
        self._attend_across_buckets = attend_across_buckets
        self._rehash_each_round = rehash_each_round
        self._random_rotations_per_head = random_rotations_per_head

        # 是否返回注意力矩阵
        self._return_attn = return_attn

        # 用于缓存可逆网络的桶,作者报告这样可以使 Reformer 在深度上工作
        self._cache = {}

    # 缓存方法装饰器,用于缓存 buckets
    @cache_method_decorator('_cache', 'buckets', reexecute=True)
    # 对输入的向量进行哈希处理,将其映射到指定数量的桶中
    def hash_vectors(self, n_buckets, vecs):
        # 获取输入向量的批量大小
        batch_size = vecs.shape[0]
        # 获取输入向量所在设备
        device = vecs.device

        # 参考论文 https://arxiv.org/pdf/1509.02897.pdf
        # 为每一轮哈希采样不同的随机旋转,以减少哈希失配的概率
        assert n_buckets % 2 == 0

        rot_size = n_buckets

        rotations_shape = (
            batch_size if self._random_rotations_per_head else 1,
            vecs.shape[-1],
            self.n_hashes if self._rehash_each_round else 1,
            rot_size // 2)

        # 生成随机旋转矩阵
        random_rotations = torch.randn(rotations_shape, dtype=vecs.dtype, device=device).expand(batch_size, -1, -1, -1)

        # 对输入向量进行哈希前的丢弃处理
        dropped_vecs = self.dropout_for_hash(vecs)
        # 对丢弃后的向量进行旋转操作
        rotated_vecs = torch.einsum('btf,bfhi->bhti', dropped_vecs, random_rotations)

        if self._rehash_each_round:
            # 如果每轮都重新哈希,则将旋转后的向量进行拼接
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            # 获取每个向量对应的桶索引
            buckets = torch.argmax(rotated_vecs, dim=-1)
        else:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            # 在这种配置下,将每个项目映射到前 self.n_hashes 个桶中
            rotated_vecs = torch.squeeze(rotated_vecs, 1)
            bucket_range = torch.arange(rotated_vecs.shape[-1], device=device)
            bucket_range = torch.reshape(bucket_range, (1, -1))
            bucket_range = bucket_range.expand_as(rotated_vecs)

            # 对旋转后的向量进行排序,获取对应的桶索引
            _, buckets = sort_key_val(rotated_vecs, bucket_range, dim=-1)
            # 调整桶索引的维度
            buckets = buckets[... , -self.n_hashes:].transpose(1, 2)

        # 每个哈希轮次的桶索引现在是 (self.n_hashes, seq_len) 的形状。接下来添加偏移量,以避免不同哈希轮次的桶号重叠
        offsets = torch.arange(self.n_hashes, device=device)
        offsets = torch.reshape(offsets * n_buckets, (1, -1, 1))
        buckets = torch.reshape(buckets + offsets, (batch_size, -1,))
        # 返回最终的桶索引
        return buckets
# 定义全连接的注意力机制类
class FullQKAttention(nn.Module):
    def __init__(self, causal = False, dropout = 0.):
        super().__init__()
        self.causal = causal
        self.dropout = nn.Dropout(dropout)

    def forward(self, qk, v, query_len = None, input_mask = None, input_attn_mask = None, **kwargs):
        b, seq_len, dim = qk.shape
        query_len = default(query_len, seq_len)
        t = query_len

        q = qk[:, 0:query_len]
        qk = F.normalize(qk, 2, dim=-1).type_as(q)

        dot = torch.einsum('bie,bje->bij', q, qk) * (dim ** -0.5)

        # qk attention requires tokens not attend to self
        i = torch.arange(t)
        dot[:, i, i] = TOKEN_SELF_ATTN_VALUE
        masked_value = max_neg_value(dot)

        # Input mask for padding in variable lengthed sequences
        if input_mask is not None:
            mask = input_mask[:, 0:query_len, None] * input_mask[:, None, :]
            mask = F.pad(mask, (0, seq_len - mask.shape[-1]), value=True)
            dot.masked_fill_(~mask, masked_value)

        # Mask for post qk attention logits of the input sequence
        if input_attn_mask is not None:
            input_attn_mask = F.pad(input_attn_mask, (0, seq_len - input_attn_mask.shape[-1]), value=True)
            dot.masked_fill_(~input_attn_mask, masked_value)

        if self.causal:
            i, j = torch.triu_indices(t, t, 1)
            dot[:, i, j] = masked_value

        dot = dot.softmax(dim=-1)
        dot = self.dropout(dot)

        out = torch.einsum('bij,bje->bie', dot, v)

        return out, dot, torch.empty(0)

# 共享的 qk 注意力机制,使用全局或 LSH 注意力机制
class LSHSelfAttention(nn.Module):
    def __init__(self, dim, heads = 8, bucket_size = 64, n_hashes = 8, causal = False, dim_head = None, attn_chunks = 1, random_rotations_per_head = False, attend_across_buckets = True, allow_duplicate_attention = True, num_mem_kv = 0, one_value_head = False, use_full_attn = False, full_attn_thres = None, return_attn = False, post_attn_dropout = 0., dropout = 0., n_local_attn_heads = 0, **kwargs):
        super().__init__()
        assert dim_head or (dim % heads) == 0, 'dimensions must be divisible by number of heads'
        assert n_local_attn_heads < heads, 'local attention heads must be less than number of heads'

        dim_head = default(dim_head, dim // heads)
        dim_heads = dim_head * heads

        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head
        self.attn_chunks = default(attn_chunks, 1)

        self.v_head_repeats = (heads if one_value_head else 1)
        v_dim = dim_heads // self.v_head_repeats

        self.toqk = nn.Linear(dim, dim_heads, bias = False)
        self.tov = nn.Linear(dim, v_dim, bias = False)
        self.to_out = nn.Linear(dim_heads, dim)

        self.bucket_size = bucket_size
        self.lsh_attn = LSHAttention(bucket_size=bucket_size, n_hashes=n_hashes, causal=causal, random_rotations_per_head=random_rotations_per_head, attend_across_buckets = attend_across_buckets,  allow_duplicate_attention = allow_duplicate_attention, return_attn = return_attn, dropout = dropout, **kwargs)
        self.full_attn = FullQKAttention(causal=causal, dropout=dropout)
        self.post_attn_dropout = nn.Dropout(post_attn_dropout)

        self.use_full_attn = use_full_attn
        self.full_attn_thres = default(full_attn_thres, bucket_size)

        self.num_mem_kv = num_mem_kv
        self.mem_kv = nn.Parameter(torch.randn(1, num_mem_kv, dim, requires_grad=True)) if num_mem_kv > 0 else None

        self.n_local_attn_heads = n_local_attn_heads
        self.local_attn = LocalAttention(window_size=bucket_size * 2, causal=causal, dropout=dropout, shared_qk=True, look_forward=(1 if not causal else 0))

        self.callback = None
    # 定义前向传播函数,接受输入 x 和其他可选参数
    def forward(self, x, keys = None, input_mask = None, input_attn_mask = None, context_mask = None, pos_emb = None, **kwargs):
        # 获取输入 x 的设备和数据类型
        device, dtype = x.device, x.dtype
        # 获取输入 x 的形状信息
        b, t, e, h, dh, m, l_h = *x.shape, self.heads, self.dim_head, self.num_mem_kv, self.n_local_attn_heads

        # 初始化记忆键值对
        mem_kv = default(self.mem_kv, torch.empty(b, 0, e, dtype=dtype, device=device))
        mem = mem_kv.expand(b, m, -1)

        # 初始化键
        keys = default(keys, torch.empty(b, 0, e, dtype=dtype, device=device))
        c = keys.shape[1]

        # 计算键值对的长度
        kv_len = t + m + c
        # 判断是否使用全局注意力
        use_full_attn = self.use_full_attn or kv_len <= self.full_attn_thres

        # 将输入 x、记忆和键连接起来
        x = torch.cat((x, mem, keys), dim=1)
        # 将输入 x 转换为查询和键
        qk = self.toqk(x)
        # 将输入 x 转换为值
        v = self.tov(x)
        # 复制值以匹配头数
        v = v.repeat(1, 1, self.v_head_repeats)

        # 定义合并头部的函数
        def merge_heads(v):
            return v.view(b, kv_len, h, -1).transpose(1, 2)

        # 定义分割头部的函数
        def split_heads(v):
            return v.view(b, h, t, -1).transpose(1, 2).contiguous()

        # 合并批次和头部维度
        merge_batch_and_heads = partial(merge_dims, 0, 1)

        # 对查询和键值对进行头部合并
        qk, v = map(merge_heads, (qk, v))

        # 判断是否有局部注意力
        has_local = l_h > 0
        lsh_h = h - l_h

        # 分割索引函数
        split_index_fn = partial(split_at_index, 1, l_h)
        (lqk, qk), (lv, v) = map(split_index_fn, (qk, v))
        lqk, qk, lv, v = map(merge_batch_and_heads, (lqk, qk, lv, v))

        # 初始化掩码字典
        masks = {}
        # 如果存在输入掩码或上下文掩码
        if input_mask is not None or context_mask is not None:
            default_mask = torch.tensor([True], device=device)
            i_mask = default(input_mask, default_mask.expand(b, t))
            m_mask = default_mask.expand(b, m)
            c_mask = default(context_mask, default_mask.expand(b, c))
            mask = torch.cat((i_mask, m_mask, c_mask), dim=1)
            mask = merge_batch_and_heads(expand_dim(1, lsh_h, mask))
            masks['input_mask'] = mask

        # 如果存在输入注意力掩码
        if input_attn_mask is not None:
            input_attn_mask = merge_batch_and_heads(expand_dim(1, lsh_h, input_attn_mask))
            masks['input_attn_mask'] = input_attn_mask

        # 根据是否使用全局注意力选择不同的注意力函数
        attn_fn = self.lsh_attn if not use_full_attn else self.full_attn
        partial_attn_fn = partial(attn_fn, query_len = t, pos_emb = pos_emb, **kwargs)
        attn_fn_in_chunks = process_inputs_chunk(partial_attn_fn, chunks = self.attn_chunks)

        # 执行注意力函数
        out, attn, buckets = attn_fn_in_chunks(qk, v, **masks)

        # 如果存在回调函数,则执行回调
        if self.callback is not None:
            self.callback(attn.reshape(b, lsh_h, t, -1), buckets.reshape(b, lsh_h, -1))

        # 如果存在局部注意力
        if has_local:
            lqk, lv = lqk[:, :t], lv[:, :t]
            local_out = self.local_attn(lqk, lqk, lv, input_mask=input_mask)
            local_out = local_out.reshape(b, l_h, t, -1)
            out = out.reshape(b, lsh_h, t, -1)
            out = torch.cat((local_out, out), dim=1)

        # 分割头部并重塑输出
        out = split_heads(out).view(b, t, -1)
        out = self.to_out(out)
        return self.post_attn_dropout(out)
# 定义 GELU 激活函数类,继承自 nn.Module
class GELU_(nn.Module):
    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 使用 GELU 激活函数计算输出
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))

# 如果 nn 模块中存在 GELU 类,则使用 nn.GELU,否则使用自定义的 GELU_ 类
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

# 定义前馈神经网络类 FeedForward,继承自 nn.Module
class FeedForward(nn.Module):
    # 初始化函数,接受维度 dim、倍数 mult、dropout 概率、激活函数 activation 和 glu 标志
    def __init__(self, dim, mult=4, dropout=0., activation=None, glu=False):
        super().__init__()
        # 设置激活函数为默认值 GELU
        activation = default(activation, GELU)

        self.glu = glu
        # 第一层全连接层,输入维度为 dim,输出维度为 dim * mult * (2 if glu else 1)
        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
        # 激活函数层
        self.act = activation()
        # Dropout 层
        self.dropout = nn.Dropout(dropout)
        # 第二层全连接层,输入维度为 dim * mult,输出维度为 dim
        self.w2 = nn.Linear(dim * mult, dim)

    # 前向传播函数,接受输入 x 和其他参数
    def forward(self, x, **kwargs):
        # 如果不使用 glu
        if not self.glu:
            # 进行第一层全连接层和激活函数的计算
            x = self.w1(x)
            x = self.act(x)
        else:
            # 如果使用 glu,进行特殊处理
            x, v = self.w1(x).chunk(2, dim=-1)
            x = self.act(x) * v

        # Dropout
        x = self.dropout(x)
        # 第二层全连接层计算结果
        x = self.w2(x)
        return x

# 绝对位置嵌入类,继承自 nn.Module
class AbsolutePositionalEmbedding(nn.Module):
    # 初始化函数,接受维度 dim 和最大序列长度 max_seq_len
    def __init__(self, dim, max_seq_len):
        super().__init__()
        # 创建 Embedding 层,输入维度为最大序列长度,输出维度为 dim
        self.emb = nn.Embedding(max_seq_len, dim)

    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 生成序列长度的张量 t
        t = torch.arange(x.shape[1], device=x.device)
        # 返回位置嵌入结果
        return self.emb(t)

# 固定位置嵌入类,继承自 nn.Module
class FixedPositionalEmbedding(nn.Module):
    # 初始化函数,接受维度 dim
    def __init__(self, dim):
        super().__init__()
        # 计算频率
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # 将频率作为缓冲区
        self.register_buffer('inv_freq', inv_freq)

    # 前向传播函数,接受输入 x 和序列维度 seq_dim
    def forward(self, x, seq_dim=1):
        # 生成序列长度的张量 t
        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
        # 计算正弦和余弦位置嵌入
        sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return emb[None, :, :].type_as(x)

# 旋转位置嵌入辅助函数,用于旋转每两个元素
def rotate_every_two(x):
    x = rearrange(x, '... (d j) -> ... d j', j=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d j -> ... (d j)')

# 应用旋转位置嵌入函数,接受查询键 qk 和正弦位置 sinu_pos
def apply_rotary_pos_emb(qk, sinu_pos):
    sinu_pos = sinu_pos.type(qk.dtype)
    sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j=2)
    sin, cos = sinu_pos.unbind(dim=-2)
    sin, cos = map(lambda t: repeat(t, 'n d -> n (d j)', j=2), (sin, cos))
    seq_len = sin.shape[0]
    qk, qk_pass = qk[:, :seq_len], qk[:, seq_len:]
    qk = (qk * cos) + (rotate_every_two(qk) * sin)
    return torch.cat((qk, qk_pass), dim=1)

# Reformer 语言模型类,继承自 nn.Module
class Reformer(nn.Module):
    # 初始化函数,设置模型参数
    def __init__(self, dim, depth, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_activation = None, ff_mult = 4, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
        # 调用父类的初始化函数
        super().__init__()
        # 设置模型的维度和深度
        self.dim = dim
        self.depth = depth

        # 设置桶的大小和记忆键值对的数量
        self.bucket_size = bucket_size
        self.num_mem_kv = num_mem_kv

        # 设置全局注意力的阈值
        self.full_attn_thres = full_attn_thres

        # 定义获取注意力和前馈网络的函数
        get_attn = lambda: LSHSelfAttention(dim, heads, bucket_size, n_hashes, causal = causal, dim_head = dim_head, dropout = lsh_dropout, post_attn_dropout = post_attn_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads)
        get_ff = lambda: Chunk(ff_chunks, FeedForward(dim, dropout = ff_dropout, activation = ff_activation, mult = ff_mult, glu = ff_glu), along_dim = -2)
        get_pkm = lambda: PKM(dim, num_keys = pkm_num_keys)

        # 如果权重共享为真,则对获取注意力和前馈网络的函数进行缓存
        if weight_tie:
            get_attn, get_ff, get_pkm = map(cache_fn, (get_attn, get_ff, get_pkm))

        # 初始化块列表
        blocks = []

        # 根据是否使用标准化类型,选择不同的标准化函数
        norm_type = ScaleNorm if use_scale_norm else nn.LayerNorm

        # 根据是否使用 ReZero,选择不同的残差函数
        residual_fn_wrapper = ReZero if use_rezero else partial(PreNorm, norm_type, dim)

        # 循环构建深度个块
        for ind in range(depth):
            layer_num = ind + 1
            use_pkm = layer_num in cast_tuple(pkm_layers)
            parallel_net = None

            # 获取注意力和前馈网络
            attn = get_attn()

            if use_pkm:
                parallel_net = get_pkm()
            else:
                parallel_net = get_ff()

            f = residual_fn_wrapper(attn)
            g = residual_fn_wrapper(parallel_net)

            blocks.append(nn.ModuleList([f, g]))

        # 构建可逆序列
        self.layers = ReversibleSequence(nn.ModuleList(blocks), layer_dropout = layer_dropout, reverse_thres = reverse_thres, send_signal = True)

    # 前向传播函数
    def forward(self, x, **kwargs):
        # 在最后一个维度上拼接输入张量
        x = torch.cat([x, x], dim = -1)
        # 使用可逆序列进行前向传播
        x = self.layers(x, **kwargs)
        # 将结果张量按最后一个维度分块,取均值
        return torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
class ReformerLM(nn.Module):
    # 定义 ReformerLM 类,继承自 nn.Module
    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = 64, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, axial_position_emb = False, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
        # 初始化函数,接受多个参数
        super().__init__()
        # 调用父类的初始化函数

        emb_dim = default(emb_dim, dim)
        # 如果 emb_dim 为 None,则使用 dim

        self.max_seq_len = max_seq_len
        # 设置最大序列长度

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        # 创建一个嵌入层,用于将输入的 token 转换为向量表示

        self.to_model_dim = Identity() if emb_dim == dim else nn.Linear(emb_dim, dim)
        # 如果 emb_dim 等于 dim,则使用 Identity(),否则使用线性层将 emb_dim 转换为 dim

        self.pos_emb = Always(0)
        self.layer_pos_emb = Always(None)
        # 初始化位置编码

        if axial_position_emb:
            # 如果启用轴向位置编码
            axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
            # 计算轴向位置编码的形状
            self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)
            # 创建轴向位置编码
        elif absolute_position_emb:
            # 如果启用绝对位置编码
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
            # 创建绝对位置编码
        elif fixed_position_emb:
            # 如果启用固定位置编码
            self.pos_emb = FixedPositionalEmbedding(emb_dim)
            # 创建固定位置编码
        else:
            self.layer_pos_emb = FixedPositionalEmbedding(dim_head)
            # 创建固定位置编码

        self.reformer = Reformer(dim, depth, heads = heads, dim_head = dim_head, bucket_size = bucket_size, n_hashes = n_hashes, ff_chunks = ff_chunks, attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie, lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation, ff_glu = ff_glu, ff_dropout = ff_dropout, post_attn_dropout = 0., layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head, use_scale_norm = use_scale_norm, use_rezero = use_rezero, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, reverse_thres = reverse_thres, num_mem_kv = num_mem_kv, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
        # 创建 Reformer 模型

        self.norm = nn.LayerNorm(dim)
        # 创建 LayerNorm 层

        if return_embeddings:
            self.out = Identity()
            return
            # 如果需要返回嵌入向量,则直接返回

        self.out = nn.Sequential(
            nn.Linear(dim, emb_dim) if emb_dim != dim else Identity(),
            nn.Linear(emb_dim, num_tokens) if not weight_tie_embedding else MatrixMultiply(self.token_emb.weight, transpose=True, normalize=True)
        )
        # 创建输出层,根据是否需要权重共享选择不同的操作

    def forward(self, x, **kwargs):
        # 前向传播函数
        x = self.token_emb(x)
        # 将输入的 token 转换为向量表示
        x = x + self.pos_emb(x)
        # 添加位置编码到输入向量中

        layer_pos_emb = self.layer_pos_emb(x)
        # 获取层级位置编码
        x = self.to_model_dim(x)
        # 将输入向量转换为模型维度
        x = self.reformer(x, pos_emb = layer_pos_emb, **kwargs)
        # 使用 Reformer 模型进行处理
        x = self.norm(x)
        # 对输出进行 LayerNorm 处理
        return self.out(x)
        # 返回输出结果

.\lucidrains\reformer-pytorch\reformer_pytorch\reversible.py

import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states



# 创建一个继承自 nn.Module 的 Deterministic 类,用于记录和设置随机数生成器状态
# 参考链接:https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)



# 创建一个继承自 nn.Module 的 ReversibleBlock 类,用于实现可逆块
# 受 https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 启发
# 一旦多 GPU 工作正常,重构并向源发送 PR
class ReversibleBlock(nn.Module):
    def __init__(self, f, g, depth=None, send_signal = False):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

        self.depth = depth
        self.send_signal = send_signal

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        if self.send_signal:
            f_args['_reverse'] = g_args['_reverse'] = False
            f_args['_depth'] = g_args['_depth'] = self.depth

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        if self.send_signal:
            f_args['_reverse'] = g_args['_reverse'] = True
            f_args['_depth'] = g_args['_depth'] = self.depth

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx



# 创建一个继承自 nn.Module 的 IrreversibleBlock 类,用于实现不可逆块
class IrreversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = f
        self.g = g

    def forward(self, x, f_args, g_args):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1 = x1 + self.f(x2, **f_args)
        y2 = x2 + self.g(y1, **g_args)
        return torch.cat([y1, y2], dim=2)



# 创建一个继承自 Function 的 _ReversibleFunction 类,用于实现可逆函数
class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        ctx.kwargs = kwargs
        for block in blocks:
            x = block(x, **kwargs)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    # 定义一个反向传播函数,接收上下文和梯度作为参数
    def backward(ctx, dy):
        # 从上下文中获取 y 值
        y = ctx.y
        # 从上下文中获取关键字参数
        kwargs = ctx.kwargs
        # 反向遍历上下文中的块列表
        for block in ctx.blocks[::-1]:
            # 调用每个块的反向传播方法,更新 y 和 dy
            y, dy = block.backward_pass(y, dy, **kwargs)
        # 返回更新后的梯度
        return dy, None, None
# 定义一个可逆序列的神经网络模块
class ReversibleSequence(nn.Module):
    # 初始化函数,接受一些参数用于构建可逆序列
    def __init__(self, blocks, layer_dropout = 0., reverse_thres = 0, send_signal = False):
        super().__init__()
        # 设置层级丢弃率和反转阈值
        self.layer_dropout = layer_dropout
        self.reverse_thres = reverse_thres

        # 创建可逆块的模块列表,根据是否需要反转选择不同的块
        self.blocks = nn.ModuleList([ReversibleBlock(f, g, depth, send_signal) for depth, (f, g) in enumerate(blocks)])
        self.irrev_blocks = nn.ModuleList([IrreversibleBlock(f=f, g=g) for f, g in blocks])

    # 前向传播函数,接受输入和一些参数,根据是否需要反转选择不同的块进行处理
    def forward(self, x, arg_route = (True, False), **kwargs):
        # 判断是否需要反转
        reverse = x.shape[1] > self.reverse_thres
        blocks = self.blocks if reverse else self.irrev_blocks

        # 如果处于训练状态且设置了层级丢弃率
        if self.training and self.layer_dropout > 0:
            # 随机选择是否丢弃某些块
            to_drop = torch.empty(len(self.blocks)).uniform_(0, 1) < self.layer_dropout
            blocks = [block for block, drop in zip(self.blocks, to_drop) if not drop]
            blocks = self.blocks[:1] if len(blocks) == 0 else blocks

        # 根据参数路由设置不同的参数
        f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)
        block_kwargs = {'f_args': f_args, 'g_args': g_args}

        # 如果不需要反转,则依次对每个块进行处理
        if not reverse:
            for block in blocks:
                x = block(x, **block_kwargs)
            return x

        # 如果需要反转,则调用自定义的可逆函数进行处理
        return _ReversibleFunction.apply(x, blocks, block_kwargs)

.\lucidrains\reformer-pytorch\reformer_pytorch\__init__.py

# 从 reformer_pytorch 模块中导入 LSHAttention, LSHSelfAttention, Reformer, ReformerLM 类
from reformer_pytorch.reformer_pytorch import LSHAttention, LSHSelfAttention, Reformer, ReformerLM
# 从 reformer_pytorch 模块中导入 ReformerEncDec 类
from reformer_pytorch.reformer_enc_dec import ReformerEncDec
# 从 reformer_pytorch 模块中导入 Recorder 类
from reformer_pytorch.recorder import Recorder
# 从 reformer_pytorch 模块中导入 Autopadder 类
from reformer_pytorch.autopadder import Autopadder

.\lucidrains\reformer-pytorch\setup.py

# 导入设置安装包和查找包的模块
from setuptools import setup, find_packages

# 设置安装包的信息
setup(
  # 包的名称
  name = 'reformer_pytorch',
  # 查找包,排除 examples 和 pretraining 文件夹
  packages = find_packages(exclude=['examples', 'pretraining']),
  # 版本号
  version = '1.4.4',
  # 许可证
  license='MIT',
  # 描述
  description = 'Reformer, the Efficient Transformer, Pytorch',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/reformer-pytorch',
  # 关键词
  keywords = ['transformers', 'attention', 'artificial intelligence'],
  # 安装依赖
  install_requires=[
    'axial-positional-embedding>=0.1.0',
    'einops',
    'local-attention',
    'product-key-memory',
    'torch'
  ],
  # 分类
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

ReLA (Rectified Linear Attention) Transformer

Implementation of a Transformer using ReLA (Rectified Linear Attention). It will also contain an attempt to combine the feedforward into the ReLA layer as memory key / values, as proposed in All Attention, suggestion made by Charles Foster.

Install

$ pip install rela-transformer

Usage

import torch
from rela_transformer import ReLATransformer

model = ReLATransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 8,
    max_seq_len = 1024,
    dim_head = 64,
    heads = 8
)

x = torch.randint(0, 20000, (1, 1024))
mask = torch.ones(1, 1024).bool()

logits = model(x, mask = mask) # (1, 1024, 20000)

Enwik8

$ python train.py

Citations

@misc{zhang2021sparse,
    title   = {Sparse Attention with Linear Units},
    author  = {Biao Zhang and Ivan Titov and Rico Sennrich},
    year    = {2021},
    eprint  = {2104.07012},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

.\lucidrains\rela-transformer\rela_transformer\autoregressive_wrapper.py

# 导入必要的库
from functools import partial
import torch
import random
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

# 定义函数,判断值是否存在
def exists(val):
    return val is not None

# 定义函数,返回值或默认值
def default(value, default):
    return value if exists(value) else default

# 定义函数,计算输入张量的对数
def log(t, eps=1e-9):
    return torch.log(t + eps)

# 定义函数,根据阈值返回前k个概率最高的logits
def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# 定义一个自回归包装器类
class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = None, pad_value = 0):
        super().__init__()        
        self.pad_value = pad_value
        self.ignore_index = default(ignore_index, pad_value)

        self.net = net
        self.max_seq_len = net.max_seq_len

    # 生成序列的方法
    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, **kwargs):
        was_training = self.net.training
        num_dims = len(start_tokens.shape)
        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]

            logits = self.net(x, **kwargs)
            logits = logits[:, -1, :]
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)

            gumbel_noise = -log(-log(torch.zeros_like(filtered_logits).uniform_(0, 1)))
            sample = ((filtered_logits / temperature) + gumbel_noise).argmax(dim=-1)

            out = torch.cat((out, sample[:, None]), dim=-1)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]
        self.net.train(was_training)
        return out

    # 前向传播方法
    def forward(self, x, *args, **kwargs):
        inp, labels = x[:, :-1], x[:, 1:]
        out = self.net(inp, *args, **kwargs)
        loss = F.cross_entropy(out.transpose(1, 2), labels, ignore_index = self.ignore_index)
        return loss

.\lucidrains\rela-transformer\rela_transformer\rela_transformer.py

# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 torch 中导入 nn、einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange、repeat 函数
from einops import rearrange, repeat

# 定义辅助函数 exists,用于检查值是否存在
def exists(val):
    return val is not None

# 定义 GatedRMSNorm 类,继承自 nn.Module
class GatedRMSNorm(nn.Module):
    def __init__(
        self,
        dim,
        eps = 1e-8
    ):
        super().__init__()
        # 初始化缩放因子 scale
        self.scale = dim ** -0.5
        # 初始化 eps
        self.eps = eps
        # 初始化可学习参数 w 和 g
        self.w = nn.Parameter(torch.ones(dim))
        self.g = nn.Parameter(torch.ones(dim))

    # 前向传播函数
    def forward(self, x):
        # 计算输入 x 的 L2 范数,并进行缩放
        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
        # 对输入 x 进行归一化处理
        normed_x = x / norm.clamp(min = self.eps) * self.g
        # 返回经过门控的 RMS 归一化结果
        return normed_x * (x * self.w).sigmoid()

# 定义 FeedForward 函数,返回一个包含线性层和 GELU 激活函数的序列
def FeedForward(dim, mult = 4):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Linear(dim * mult, dim)
    )

# 定义 ReLA 类,继承自 nn.Module
class ReLA(nn.Module):
    def __init__(
        self,
        *,
        dim,
        causal = True,
        dim_head = 64,
        heads = 8,
        num_memory_kv = 0,
        relu_squared = False
    ):
        super().__init__()
        # 初始化头数和内部维度
        self.heads = heads
        inner_dim = dim_head * heads
        # 初始化缩放因子 scale
        self.scale = dim_head ** -0.5
        # 初始化是否是因果关系
        self.causal = causal
        # 初始化是否对激活函数进行平方操作
        self.relu_squared = relu_squared
        # 初始化 RMS 归一化层
        self.norm = GatedRMSNorm(dim)

        # 初始化 q、k、v 的线性层
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # 初始化记忆键值对
        self.mem_k = nn.Parameter(torch.randn(num_memory_kv, inner_dim))
        self.mem_v = nn.Parameter(torch.randn(num_memory_kv, inner_dim))

        # 初始化值的 RMS 归一化层和输出层
        self.norm_values = GatedRMSNorm(dim_head)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
        )

    # 前向传播函数
    def forward(self, x, mask = None):
        # 获取输入 x 的批量大小和设备信息
        b, device = x.shape[0], x.device
        # 对输入 x 进行 RMS 归一化处理
        x = self.norm(x)
        h = self.heads

        # 将输入 x 经过 qkv 线性层并分块
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # 将记忆键值对进行扩展并拼接到 k、v 中
        mem_k, mem_v = map(lambda t: repeat(t, 'n d -> b n d', b = b), (self.mem_k, self.mem_v))
        k = torch.cat((mem_k, k), dim = 1)
        v = torch.cat((mem_v, v), dim = 1)

        # 重排 q、k、v 的维度
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # 对 q 进行缩放
        q = q * self.scale
        # 计算注意力分数
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # 计算注意力值,并进行 ReLU 激活
        attn = F.relu(sim)

        # 如果设置了 relu_squared 标志,则对注意力值进行平方操作
        if self.relu_squared:
            attn = attn ** 2

        # 如果存在 mask,则进行 mask 操作
        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            attn = attn.masked_fill(~mask, 0.)

        # 如果是因果关系,进行因果 mask 操作
        if self.causal:
            i, j = attn.shape[-2:]
            causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            attn = attn.masked_fill(causal_mask, 0.)

        # 计算输出
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = self.norm_values(out)

        # 重排输出维度
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# 定义 ReLATransformer 类,继承自 nn.Module
class ReLATransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        causal = True,
        heads = 8,
        dim_head = 64,
        num_memory_kv = 0,
        no_ff = False,
        ff_mult = 4,
        relu_squared = False
    ):
        super().__init__()
        # 初始化最大序列长度、token 词嵌入和位置嵌入
        self.max_seq_len = max_seq_len
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # 初始化层列表
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                ReLA(dim = dim, relu_squared = relu_squared, heads = heads, dim_head = dim_head, num_memory_kv = num_memory_kv, causal = causal),
                FeedForward(dim = dim, mult = ff_mult) if not no_ff else None
            ]))

        # 初始化输出层
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )
    # 定义前向传播函数,接受输入张量 x 和掩码 mask,默认为 None
    def forward(self, x, mask = None):
        # 获取输入张量 x 的维度 n 和设备信息
        n, device = x.shape[1], x.device
        # 对输入张量 x 进行 token embedding
        x = self.token_emb(x)
        # 根据输入张量 x 的长度 n,生成位置编码 pos_emb
        pos_emb = self.pos_emb(torch.arange(n, device = device))
        # 将位置编码与 token embedding 相加
        x = x + rearrange(pos_emb, 'n d -> 1 n d')

        # 遍历每个注意力层和前馈层
        for attn, ff in self.layers:
            # 使用注意力层处理输入张量 x,并将结果与原始输入相加
            x = attn(x, mask = mask) + x

            # 如果前馈层存在
            if exists(ff):
                # 使用前馈层处理输入张量 x,并将结果与原始输入相加
                x = ff(x) + x

        # 将处理后的张量 x 转换为最终的输出 logits
        return self.to_logits(x)

.\lucidrains\rela-transformer\rela_transformer\__init__.py

# 从 rela_transformer.rela_transformer 模块中导入 ReLATransformer 类
from rela_transformer.rela_transformer import ReLATransformer

.\lucidrains\rela-transformer\setup.py

# 导入设置和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'rela-transformer',  # 包的名称
  packages = find_packages(exclude=[]),  # 查找所有包
  version = '0.0.7',  # 版本号
  license='MIT',  # 许可证
  description = 'ReLA Transformer',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/rela-transformer',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'deep learning',
    'transformers',
    'attention-mechanism',
  ],
  install_requires=[  # 安装依赖
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[  # 分类器列表
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\rela-transformer\train.py

# 导入所需的模块
from rela_transformer import ReLATransformer
from rela_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# 常量定义
NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 3e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 512

# 辅助函数

# 从 token 解码为字符
def decode_token(token):
    return str(chr(max(32, token)))

# 从 tokens 解码为字符串
def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# 实例化模型

# 创建 ReLATransformer 模型
model = ReLATransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    max_seq_len = SEQ_LEN,
    heads = 8,
    causal = True
)

# 将模型包装在 AutoregressiveWrapper 中
model = AutoregressiveWrapper(model)
# 将模型移动到 GPU 上
model.cuda()

# 准备 enwik8 数据

# 读取 enwik8 数据集
with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

# 创建自定义数据集类
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

# 创建训练集和验证集的 DataLoader
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# 优化器

# 创建 Adam 优化器
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# 训练

# 循环训练指定次数
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # 梯度累积
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f'training loss: {loss.item()}')
    # 梯度裁剪
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        inp = inp[:SEQ_LEN]
        prime = decode_tokens(inp)
        print(f'%s \n\n %s', (prime, '*' * 100))

        sample = model.generate(inp[None, :], GENERATE_LENGTH)
        output_str = decode_tokens(sample.squeeze(0))
        print(output_str)

Remixer - Pytorch

Implementation of the Remixer Block from the Remixer paper, in Pytorch. It claims that substituting the feedforwards in transformers with sequence wide mixing followed by multiplication and subtraction leads to better language understanding results.

Install

$ pip install remixer-pytorch

Usage

import torch
from remixer_pytorch import RemixerBlock

block = RemixerBlock(
    dim = 512,
    seq_len = 1024
)

x = torch.randn(1, 1024, 512)
block(x) # (1, 1024, 512)

Citations

@inproceedings{anonymous,
    title   = {Remixers: A Mixer-Transformer Architecture with Compositional Operators for Natural Language Understanding },
    author  = {Anonymous},
    year = {2021},
    url = {https://openreview.net/forum?id=9FHQHJnRtfL}
}

.\lucidrains\remixer-pytorch\remixer_pytorch\remixer_pytorch.py

# 导入 torch 库
import torch
# 导入 torch.nn.functional 模块,并重命名为 F
import torch.nn.functional as F
# 从 torch 中导入 nn、einsum 模块
from torch import nn, einsum
# 从 einops 中导入 rearrange 函数
from einops import rearrange

# 定义 RemixerBlock 类,继承自 nn.Module
class RemixerBlock(nn.Module):
    # 初始化函数,接受 dim、seq_len、causal 和 bias 四个参数
    def __init__(
        self,
        dim,
        seq_len,
        causal = False,
        bias = False
    ):
        super().__init__()
        # 初始化 causal 属性
        self.causal = causal
        # 初始化 proj_in 属性为 Linear 层,输入维度为 dim,输出维度为 2 * dim
        self.proj_in = nn.Linear(dim, 2 * dim, bias = bias)
        # 初始化 mixer 属性为 nn.Parameter,值为随机生成的 seq_len x seq_len 的张量
        self.mixer = nn.Parameter(torch.randn(seq_len, seq_len))
        # 初始化 alpha 属性为 nn.Parameter,值为 0 的张量
        self.alpha = nn.Parameter(torch.tensor(0.))
        # 初始化 proj_out 属性为 Linear 层,输入维度为 dim,输出维度为 dim
        self.proj_out = nn.Linear(dim, dim, bias = bias)

    # 前向传播函数,接受输入 x
    def forward(self, x):
        # 获取 mixer、causal 和 device 属性
        mixer, causal, device = self.mixer, self.causal, x.device
        # 将输入 x 经过 proj_in 层并分割成两部分,x 和 gate
        x, gate = self.proj_in(x).chunk(2, dim = -1)
        # 对 gate 部分进行 gelu 激活函数处理,再与 x 相乘
        x = F.gelu(gate) * x

        # 如果 causal 为 True
        if self.causal:
            # 获取序列长度 seq
            seq = x.shape[1]
            # 创建 mask_value 为 x 数据类型的最小值
            mask_value = -torch.finfo(x.dtype).max
            # 创建上三角矩阵 mask,大小为 (seq, seq)
            mask = torch.ones((seq, seq), device = device, dtype=torch.bool).triu(1)
            # 限制 mixer 的大小为 (seq, seq),并根据 mask 进行填充
            mixer = mixer[:seq, :seq]
            mixer = mixer.masked_fill(mask, mask_value)

        # 对 mixer 进行 softmax 处理
        mixer = mixer.softmax(dim = -1)
        # 使用 einsum 进行矩阵乘法,得到 mixed
        mixed = einsum('b n d, m n -> b m d', x, mixer)

        # 获取 alpha,并进行 sigmoid 处理
        alpha = self.alpha.sigmoid()
        # 计算输出 out,根据 alpha 对 x 和 mixed 进行加权平均
        out = (x * mixed) * alpha + (x - mixed) * (1 - alpha)

        # 将 out 经过 proj_out 层得到最终输出
        return self.proj_out(out)

.\lucidrains\remixer-pytorch\remixer_pytorch\__init__.py

# 从 remixer_pytorch.remixer_pytorch 模块中导入 RemixerBlock 类
from remixer_pytorch.remixer_pytorch import RemixerBlock

.\lucidrains\remixer-pytorch\setup.py

# 导入设置工具和查找包工具
from setuptools import setup, find_packages

# 设置包的元信息
setup(
  # 包名
  name = 'remixer-pytorch',
  # 查找所有包,不排除任何包
  packages = find_packages(exclude=[]),
  # 版本号
  version = '0.0.3',
  # 许可证
  license='MIT',
  # 描述
  description = 'Remixer - Pytorch',
  # 作者
  author = 'Phil Wang',
  # 作者邮箱
  author_email = 'lucidrains@gmail.com',
  # 项目链接
  url = 'https://github.com/lucidrains/remixer-pytorch',
  # 关键词列表
  keywords = [
    'artificial intelligence',
    'transformer',
    'feedforward',
    'mlp-mixer'
  ],
  # 安装依赖
  install_requires=[
    'einops>=0.3',
    'torch>=1.6'
  ],
  # 分类器列表
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值