import torch
import numpy as np
class MaxState(torch.nn.Module):
def __init__(self, hidden_dim, heads, win):
super(MaxState, self).__init__()
assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
self.head_size = hidden_dim // heads
self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head_num = heads
self.hidden = hidden_dim
def forward(self, input_data, state=None):
# self.head.to(device)
b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size
out = self.head0(input_data)
# 0版
# out1 = torch.max(torch.concat([1-torch.exp(self.head1(input_data).unsqueeze(-1)),1-torch.exp(out.unsqueeze(-1))], -1), -1)[0]
# 1版
# out1 = torch.min(torch.concat(
# [1-torch.exp(h ** 0.5-self.head1(input_data).unsqueeze(-1)), 1-torch.exp(h ** 0.5-out.unsqueeze(-1))],
# -1), -1)[0]
# 2版 超过12层
out1 = torch.min(torch.concat(
[h ** 0.5 - torch.exp(self.head1(input_data).unsqueeze(-1)), h ** 0.5-torch.exp(h ** 0.5 - out.unsqueeze(-1))],
-1), -1)[0]
#
out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])
out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])
# out1 = self.head1(input_data).reshape([b, s, k, h]).permute([0, 2, 1, 3])
out = torch.cummax(out * (torch.exp(out1)+h**0.5), 2)[0]
out = out.permute([0, 2, 1, 3])
out = out.reshape([b, s, -1])
out = torch.min(torch.concat(
[h **0.5-torch.exp(self.head2(input_data).unsqueeze(-1)), torch.exp(h **0.5-out.unsqueeze(-1))],
-1), -1)[0]
# out = torch.min(torch.concat(
# [(out-torch.exp(self.head2(input_data))).unsqueeze(-1), torch.exp(h ** 0.5-out.unsqueeze(-1))],
# -1), -1)[0]
return out, state
class KAttention(torch.nn.Module):
def __init__(self, hidden_dim, heads):
super(KAttention, self).__init__()
assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
self.head_size = hidden_dim // heads
self.q = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.k = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.v = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
# self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head_num = heads
def forward(self, x, state=None):
b, s, h, d = x.shape[0], x.shape[1], self.head_num, self.head_size
q = self.q(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
k = self.k(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
v = self.v(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
qk = (q @ k.permute([0, 1, 3, 2])) / d ** 0.5
mask = torch.triu(torch.ones(s, s).to(device))
qk = torch.where(mask.T == 1, qk, torch.Tensor([-float('inf')]).to(device))
qkv = torch.nn.functional.softmax(qk, -1) @ v
# v + torch.arange(1, 3 * s, 3).reshape([1, 1, -1, 1]).to(device) / s / 3)
qkv = qkv.permute([0, 2, 1, 3]).reshape([b, s, -1])
#
return qkv, state
class FeedForward(torch.nn.Module):
def __init__(self, hidden_size):
super(FeedForward, self).__init__()
self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
self.relu = torch.nn.ReLU()
def forward(self, x):
x1 = self.ffn1(x)
x2 = self.relu(self.gate(x))
x = x1 * x2
x = self.ffn2(x)
return x
class DecoderLayer(torch.nn.Module):
def __init__(self, hidden_size, num_heads):
super(DecoderLayer, self).__init__()
# self.self_attention = MaskMultiHeadAttention(hidden_size, num_heads)
self.self_attention = MaxState(hidden_size, num_heads, 8)
# self.self_attention = KAttention(hidden_size, num_heads)
self.ffn = FeedForward(hidden_size)
self.layer_norm = torch.nn.LayerNorm(hidden_size)
def forward(self, x, state=None, seq_len=None):
x1, state = self.self_attention(x, state)
x = self.layer_norm(self.ffn(x1) + x)
return x, state
class SamOut(torch.nn.Module):
def __init__(self, voc_size, hidden_size, num_heads, num_layers):
super(SamOut, self).__init__()
self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
self.pos = torch.nn.Embedding(1024, hidden_size)
self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
self.head = torch.nn.Linear(hidden_size, voc_size, False)
# self.head_state = torch.nn.Linear(hidden_size, num_layers, False)
self.down = torch.nn.ModuleList(
[torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])
def state_forward(self, state, pos, x):
if state is None:
state = [None] * len(self.decoder_layers)
i = 0
for ii, decoder_layer in enumerate(self.decoder_layers):
x = self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to(device) + pos, x], -1))
x1, state[i] = decoder_layer(x, state[i])
x = x1 + x
i += 1
return x, state
def pos_forward(self, x):
if x.shape[1] >= 1024:
pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)
pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos
else:
pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)
return pos
def forward(self, x0):
x0, _ = self.one_forward(x0, state=None)
return x0, _
def one_forward(self, x, state=None, seq_len=None):
x = self.em(x)
pos = self.pos_forward(x)
x, state = self.state_forward(state, pos, x)
return self.head(x), state
device = "cuda"
if __name__ == '__main__':
net = SamOut(235, 256, 16, 4)
net.to(device)
net(torch.randint(0, 200, [2, 8 * 13]).to(device))
#
这段代码定义了一个基于PyTorch的神经网络模型,用于序列到序列的转换任务。以下是代码的主要组成部分和功能概述:
- MaxState类:这是一个自定义的注意力机制层,用于处理序列数据。它包含了多个线性层,用于计算注意力权重,并通过累积最大值的方式来更新状态。
- KAttention类:这是另一个自定义的注意力机制层,实现了基于键值对的注意力机制。
- FeedForward类:这是一个前馈神经网络层,包含两个线性层和一个ReLU激活函数,用于在注意力机制之后处理数据。
- DecoderLayer类:这是一个解码器层,包含一个注意力层和一个前馈神经网络层,并使用层归一化。
- SamOut类:这是整个模型的主体,包含嵌入层、位置编码、多个解码器层和一个输出层。它还负责处理状态前向传播和位置编码前向传播。
- 设备配置:代码最后部分将模型移动到CUDA设备上,以便使用GPU进行加速计算。
- 主函数:在主函数中,创建了一个SamOut实例,并将其应用于一个随机整数矩阵,模拟输入数据。
整体而言,这个模型适用于处理序列数据,如自然语言处理任务中的机器翻译、文本摘要等。通过使用注意力机制和前馈神经网络,模型能够学习输入序列和输出序列之间的复杂关系。