DataLoader处理不定长数据并使用RNN训练

使用DataLoader输入不定长序列到RNN网络

1.自定义 collate_fn

class subDataset(Dataset.Dataset):
    """Dataset over parallel lists of variable-length sequences and labels.

    Each item is returned as (float32 sequence tensor, label tensor); the
    sequences may have different lengths, so batching requires a custom
    collate_fn (see `collate_fn` below).
    """
    def __init__(self, Data_1, Label):
        self.Data_1 = Data_1  # list of variable-length sequences
        self.Label = Label    # list of per-sample labels, parallel to Data_1
    def __len__(self):
        return len(self.Data_1)
    def __getitem__(self, item):
        # Bug fix: the torch.Tensor *constructor* does not accept a dtype
        # keyword (raises TypeError at runtime); torch.tensor is the factory
        # that does. Labels keep their natural dtype — the training loop
        # calls target.long() before the loss anyway.
        data_1 = torch.tensor(self.Data_1[item], dtype=torch.float32)
        label = torch.tensor(self.Label[item])
        return data_1, label

def collate_fn(data):
    """Collate a batch of (sequence, label) pairs into a zero-padded batch.

    Sorts the batch by sequence length, longest first — required by
    pack_padded_sequence with its default enforce_sorted=True — then pads
    every sequence to the batch max length.

    Args:
        data: list of (sequence_tensor, label) pairs from the Dataset.
    Returns:
        (padded, data_length, y): padded (batch, max_len, ...) tensor,
        list of true lengths (descending), and labels in the same order.
    """
    data.sort(key=lambda x: len(x[0]), reverse=True)
    data_length = [len(sq[0]) for sq in data]
    x = [i[0] for i in data]
    y = [i[1] for i in data]
    # Bug fix: the function is pad_sequence, not pad_squence (the original
    # raised AttributeError on the first batch).
    data = rnn_utils.pad_sequence(x, batch_first=True, padding_value=0)
    return data, data_length, y

# The custom collate_fn replaces the default stacking so each batch is padded
# to its own max length; only the training split is shuffled.
# NOTE(review): assumes train_dataset, test_dataset and args are defined
# earlier in the file — not visible in this excerpt.
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

2.更改训练过程中读入数据部分

def Train(epoch):
    """Run one training epoch of utt_net over train_loader.

    Relies on module-level utt_net, utt_optim, train_loader and args.
    """
    train_loss = 0
    utt_net.train()
    # Hoisted: constructing CrossEntropyLoss per batch is wasteful.
    criterion = torch.nn.CrossEntropyLoss()
    for batch_idx, (data_1, length, target) in enumerate(train_loader):
        # Lengths must stay on CPU: pack_padded_sequence requires CPU lengths.
        length = torch.tensor(np.array(length))
        target = torch.tensor(np.array(target))
        # Pack the padded batch so the RNN skips padded timesteps.
        # collate_fn already sorted by length, as enforce_sorted=True expects.
        batch_x_pack = rnn_utils.pack_padded_sequence(data_1, length, batch_first=True)
        if args.cuda:
            batch_x_pack, target = batch_x_pack.cuda(), target.cuda()
        utt_optim.zero_grad()
        utt_out = utt_net(batch_x_pack)
        # NOTE(review): squeeze() collapses *all* size-1 dims; a batch of one
        # sample would produce a 0-dim target — confirm batch_size > 1.
        target = target.squeeze()
        loss = criterion(utt_out, target.long())
        loss.backward()
        utt_optim.step()
        # Bug fix: accumulating the loss *tensor* keeps the autograd graph of
        # every batch alive for the whole epoch (memory leak); .item()
        # detaches it to a plain Python float.
        train_loss += loss.item()

3.改写模型

class Utterance_net(nn.Module):
    """Bidirectional-GRU encoder for packed variable-length sequences.

    forward() takes a PackedSequence, runs the BiGRU, max-pools the unpacked
    outputs over time, and projects to `output_size` logits.
    """
    def __init__(self, input_size, hidden_size, output_size, args):
        super(Utterance_net, self).__init__()
        # NOTE(review): `hidden_size` is ignored; the hidden width comes from
        # args.hidden_layer — confirm this is intentional.
        self.hidden_dim = args.hidden_layer
        self.num_layers = args.dia_layers
        # dropout applied to the unpacked GRU outputs in forward()
        self.dropout = nn.Dropout(args.dropout)
        # BiGRU; the GRU's own dropout only takes effect when num_layers > 1.
        self.bigru = nn.GRU(input_size, self.hidden_dim, dropout=args.dropout,
                            batch_first=True, num_layers=self.num_layers, bidirectional=True)
        # *2: forward and backward hidden states are concatenated.
        self.hidden2label = nn.Linear(self.hidden_dim * 2, output_size)


    def forward(self, input):
        """input: PackedSequence -> (batch, output_size) logits.

        Bug fixes vs. the original: nn.Dropout cannot be applied to a
        PackedSequence (it crashed), and the GRU's packed output cannot be
        transposed directly — it must be unpacked first. The GRU itself
        consumes the PackedSequence, so padded timesteps are skipped.
        """
        gru_out, _ = self.bigru(input)
        # Unpack the *output* (the original unpacked the input and then
        # discarded the result).
        gru_out, _len = rnn_utils.pad_packed_sequence(gru_out, batch_first=True)
        gru_out = self.dropout(gru_out)
        # (batch, seq, 2*hidden) -> (batch, 2*hidden, seq) for pooling over time.
        gru_out = torch.transpose(gru_out, 1, 2)
        # Max-pool across the time axis to a fixed-size vector per sample.
        # NOTE(review): pooling also covers zero-padded positions of shorter
        # sequences — only matters if all true outputs are negative; confirm.
        gru_out = F.max_pool1d(gru_out, gru_out.size(2)).squeeze(2)
        gru_out = torch.tanh(gru_out)  # F.tanh is deprecated
        y = self.hidden2label(gru_out)
        return y
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值