SamOut 参数共享

import torch


class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.head_num = heads

        self.hidden = hidden_dim

    def forward(self, input_data, state=None):
        b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size

        out = self.head0(input_data)

        out1 = self.head1(input_data)

        out2 = self.head2(input_data)

        out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])

        out = torch.cummax((out + out1) / h ** 0.5, 2)[0]

        out = out.permute([0, 2, 1, 3])
        out1 = out1.permute([0, 2, 1, 3])
        # out2 = out2.permute([0, 2, 1, 3])
        out = out.reshape([b, s, -1])
        out1 = out1.reshape([b, s, -1])
        # out2 = out2.reshape([b, s, -1])
        # out = self.layer_nor(out)

        # out = (out + out2) * out+out1

        # out3=torch.cummax(out,1)[0]
        out = (out + out2) * out + out1

        # out = self.alpha * out * (out + out2) + (1 - self.alpha) * out1

        return out, state


class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()

        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size // 2)
        self.ffn2 = torch.nn.Linear(hidden_size //2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size // 2)

        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = self.ffn1(x)
        x2 = self.relu(self.gate(x))
        xx = x1 * x2
        x = self.ffn2(xx)
        return x


class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()

        self.self_attention = MaxState(hidden_size, num_heads)

        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

        self.alpha = torch.nn.Parameter(torch.tensor(0.5))

    def forward(self, x, state=None, ):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.alpha * self.ffn(x1) + (1 - self.alpha) * x)

        return x, state


class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        # self.pos = torch.nn.Embedding(1024, hidden_size)

        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        # self.head = torch.nn.Linear(hidden_size, voc_size, False)
        #
        # self.down = torch.nn.ModuleList(
        #     [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])

    def state_forward(self, state, x):
        if state is None:
            state = [None] * len(self.decoder_layers)
        i = 0
        for ii, decoder_layer in enumerate(self.decoder_layers):
            # x = self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to(device) + pos, x], -1))

            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x
            i += 1
        return x, state

    # def pos_forward(self, x):
    #     if x.shape[1] >= 1024:
    #         pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)
    #         pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos
    #
    #     else:
    #         pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)
    #     return pos

    def forward(self, x, state=None):
        x = self.em(x)
        # pos = self.pos_forward(x)
        x, state = self.state_forward(state, x)

        # return self.head(x), state
        return x@self.em.weight.permute([1,0]), state


device = "cuda"
if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    net(torch.randint(0, 200, [2, 8 * 13]).to(device))
    #

该代码定义了一个基于PyTorch的神经网络模型,其核心是一个自定义的解码器层(DecoderLayer),它使用了多头机制(MaxState)和前馈网络(FeedForward)。整个模型(SamOut)用于处理序列数据,比如自然语言处理中的文本。让我们逐个解析这些组件,并讨论输入输出层参数共享的优势。

组件解析

  1. MaxState:

    • 定义了一个带有多个注意力头的模块,每个头都有自己的线性变换(head0, head1, head2)。
    • 使用了累积最大值函数(cummax)来聚合信息。
    • 输出维度被重新调整为批量大小、序列长度、头数和头尺寸的形式。
  2. FeedForward:

    • 一个标准的前馈网络,包括两个线性层和一个门控机制(通过ReLU激活函数实现),用于在解码器层内部处理信息。
  3. DecoderLayer:

    • 包含一个MaxState实例和一个FeedForward实例。
    • 应用了层归一化(LayerNorm)以稳定训练过程。
    • 使用可学习参数alpha来控制来自前馈网络和输入的加权和。
  4. SamOut:

    • 整合了所有上述组件,形成了完整的模型。
    • 使用嵌入层(Embedding)将词汇表索引转换为密集向量表示。
    • 模型包含多个解码器层,由ModuleList管理。
    • 最终输出通过与嵌入层权重的矩阵乘法得到,而不是使用单独的线性层作为输出层。

输入输出层参数共享的优势

SamOut类中,输出层并没有显式定义为一个线性层,而是直接通过嵌入层的权重转置进行计算(x @ self.em.weight.permute([1, 0]))。这种做法通常被称为“权重绑定”或“参数共享”,它具有以下优势:

  • 减少参数数量:由于不引入新的权重,模型的总参数量减少,这有助于降低过拟合的风险。
  • 加速训练:较少的参数意味着更少的计算资源需求,可以加快训练速度。
  • 一致性:输入层和输出层共享相同的权重,保证了模型在处理输入时学到的特征和生成输出时所依赖的特征之间的一致性。
  • 简化架构:无需额外定义输出层,简化了模型架构。

这种技术特别适用于词汇预测任务,如语言模型或机器翻译,其中输入和输出都在同一个词汇空间中。通过共享嵌入层和输出层的权重,我们可以使模型更加紧凑和高效,同时保持良好的性能。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

东方佑

你的鼓励是我最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值