import torch
import torch.nn as nn
# Sample input batch: shape (2, 3, 10). With batch_first=True (as used
# below) this is [batch_size=2, seq_len=3, input_dim=10].
x_input=torch.randn(2,3,10)
# Reference: torch.nn.RNN(input_size, hidden_size, num_layers, batch_first)
# input_size = feature dim, hidden_size = hidden dim, num_layers = stacked layers
# batch_first=True:  x is [batch, seq_len, input_size]
# batch_first=False: x is [seq_len, batch, input_size]
# NOTE(review): for nn.RNN, h0 is [num_layers, batch, hidden_size] in BOTH
# layouts — batch_first does not reorder h0 (the original comment said otherwise).
class RNN(nn.Module):
    """A minimal single-layer RNN built from ``nn.RNNCell``, unrolled manually.

    Mirrors the layout convention of ``torch.nn.RNN``: inputs are
    ``[seq_len, batch, input_size]`` by default, or
    ``[batch, seq_len, input_size]`` when ``batch_first=True``.
    """

    def __init__(self, input_size, hidden_size, batch_first=False):
        super(RNN, self).__init__()
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)
        self.batch_first = batch_first
        self.hidden_size = hidden_size

    def _initialize_hidden(self, batch_size):
        # Default start state: zeros of shape [batch, hidden_size].
        return torch.zeros((batch_size, self.hidden_size))

    def forward(self, inputs, initial_hidden=None):
        """Unroll the cell over the time dimension of ``inputs``.

        Args:
            inputs: ``[batch, seq, feat]`` if ``batch_first`` else
                ``[seq, batch, feat]``.
            initial_hidden: optional ``[batch, hidden_size]`` start state;
                zeros are used when omitted.

        Returns:
            Hidden states of every timestep, stacked along the time axis,
            in the same layout convention as ``inputs``.
        """
        if self.batch_first:
            batch_size, seq_size, feat_size = inputs.size()
            # Iterate internally with time as the leading dimension.
            inputs = inputs.permute(1, 0, 2)
        else:
            seq_size, batch_size, feat_size = inputs.size()
        # Collect the hidden state produced at each timestep.
        hiddens = []
        if initial_hidden is None:
            initial_hidden = self._initialize_hidden(batch_size)
        # Applied unconditionally so a caller-supplied state also follows
        # the input's device.
        initial_hidden = initial_hidden.to(inputs.device)
        hidden_t = initial_hidden
        for t in range(seq_size):
            # Each state depends on the current input and the previous state.
            hidden_t = self.rnn_cell(inputs[t], hidden_t)
            hiddens.append(hidden_t)
        hiddens = torch.stack(hiddens)  # [seq, batch, hidden_size]
        if self.batch_first:
            # Restore the caller's batch-first layout.
            hiddens = hiddens.permute(1, 0, 2)
        return hiddens
# Build the model (10-dim features, 15-dim hidden state) and run the
# sample batch through it once.
model = RNN(input_size=10, hidden_size=15, batch_first=True)
outputs = model(x_input)
print(outputs)
# Expected shape: torch.Size([2, 3, 15])
# RNN reimplementation from scratch (original blog title: "RNN复现")
# Originally published 2023-09-09 10:29:10