使用DataLoader输入不定长序列到RNN网络
1.自定义 collate_fn
class subDataset(Dataset.Dataset):
    """Dataset over parallel lists of variable-length sequences and labels."""

    def __init__(self, Data_1, Label):
        self.Data_1 = Data_1  # list of variable-length sequences
        self.Label = Label    # list of labels, parallel to Data_1

    def __len__(self):
        return len(self.Data_1)

    def __getitem__(self, item):
        # torch.Tensor() takes no dtype kwarg (TypeError in the original), and
        # torch.Tensor(int) allocates an *uninitialized* tensor of that size.
        # torch.tensor() copies the given data and honors dtype.
        data_1 = torch.tensor(self.Data_1[item], dtype=torch.float32)
        label = torch.tensor(self.Label[item])
        return data_1, label
def collate_fn(data):
    """Collate (sequence, label) pairs into a zero-padded batch.

    Sorts the batch by sequence length, longest first, because
    pack_padded_sequence requires descending lengths by default
    (enforce_sorted=True).

    Returns:
        (padded, lengths, labels) where padded is (batch, max_len, ...),
        lengths is a list of original sequence lengths, and labels is a
        list of the per-sample labels.
    """
    data.sort(key=lambda x: len(x[0]), reverse=True)
    data_length = [len(sq[0]) for sq in data]
    x = [i[0] for i in data]
    y = [i[1] for i in data]
    # fixed typo: pad_squence -> pad_sequence (AttributeError otherwise)
    data = rnn_utils.pad_sequence(x, batch_first=True, padding_value=0)
    return data, data_length, y
# Wire the custom collate_fn into both loaders so each batch comes back as
# (padded_data, lengths, labels) rather than default-stacked fixed-size tensors.
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
# No shuffling at evaluation time; same variable-length collation.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
2.更改训练过程中读入数据部分
def Train(epoch):
    """Run one training epoch over train_loader.

    Each batch arrives from collate_fn as (padded_data, lengths, labels);
    the padded batch is packed so the GRU skips padded timesteps.
    """
    train_loss = 0
    utt_net.train()
    # hoisted out of the loop: the criterion is invariant across batches
    criterion = torch.nn.CrossEntropyLoss()
    for batch_idx, (data_1, length, target) in enumerate(train_loader):
        # NOTE: lengths must stay on CPU for pack_padded_sequence
        length = torch.tensor(np.array(length))
        target = torch.tensor(np.array(target))
        batch_x_pack = rnn_utils.pack_padded_sequence(data_1, length, batch_first=True)
        if args.cuda:
            batch_x_pack, target = batch_x_pack.cuda(), target.cuda()
        utt_optim.zero_grad()
        utt_out = utt_net(batch_x_pack)
        target = target.squeeze()
        loss = criterion(utt_out, target.long())
        loss.backward()
        utt_optim.step()
        # .item() detaches the scalar: accumulating the loss tensor itself
        # would keep every batch's autograd graph alive and leak memory
        train_loss += loss.item()
3.改写模型
class Utterance_net(nn.Module):
    """Bidirectional-GRU encoder over a PackedSequence.

    forward() runs a bi-GRU on the packed input, max-pools the unpacked
    output over time, and maps the pooled vector to output_size logits.
    (hidden_size is accepted for interface compatibility but the GRU width
    comes from args.hidden_layer, as in the original.)
    """

    def __init__(self, input_size, hidden_size, output_size, args):
        super(Utterance_net, self).__init__()
        self.hidden_dim = args.hidden_layer
        self.num_layers = args.dia_layers
        self.dropout = nn.Dropout(args.dropout)
        self.bigru = nn.GRU(input_size, self.hidden_dim, dropout=args.dropout,
                            batch_first=True, num_layers=self.num_layers,
                            bidirectional=True)
        # bidirectional: forward/backward states are concatenated -> 2 * hidden_dim
        self.hidden2label = nn.Linear(self.hidden_dim * 2, output_size)

    def forward(self, input):
        # input is a PackedSequence. The original unpacked it into `_pad`,
        # discarded the result, and applied nn.Dropout directly to the
        # PackedSequence, which raises a TypeError. Instead: run the GRU on
        # the packed input (padding timesteps are skipped), then unpack the
        # GRU *output* for dropout and pooling.
        gru_out, _ = self.bigru(input)
        gru_out, _len = rnn_utils.pad_packed_sequence(gru_out, batch_first=True)
        gru_out = self.dropout(gru_out)
        gru_out = torch.transpose(gru_out, 1, 2)  # (B, 2H, T) for max_pool1d
        gru_out = F.max_pool1d(gru_out, gru_out.size(2)).squeeze(2)
        gru_out = torch.tanh(gru_out)  # F.tanh is deprecated
        y = self.hidden2label(gru_out)
        return y