PyTorch:
from functools import partial

# Build the DataLoader; every batch is assembled by collate_fn with the target
# length fixed to args.seq_len.
# functools.partial is used instead of a lambda because a lambda cannot be
# pickled, which breaks worker start-up when num_workers > 0 on platforms that
# spawn worker processes (e.g. Windows, macOS).
data_loader = DataLoader(
    data_set,
    batch_size=batch_size,
    shuffle=shuffle_flag,
    num_workers=args.num_workers,
    drop_last=drop_last,
    collate_fn=partial(collate_fn, max_len=args.seq_len),
)
PyTorch 可以在构造 DataLoader 时通过 collate_fn 参数,在组 batch 的阶段处理不等长的样本(这里用 max_len 指定目标长度)。
TensorFlow:
# Fit the model on the training arrays, validating on (x_val, y_val) after each
# epoch; the value returned by fit() is kept in `hist`.
hist = self.model.fit(
    x_train,
    y_train,
    batch_size=mini_batch_size,
    epochs=self.nb_epochs,
    verbose=self.verbose,
    validation_data=(x_val, y_val),
    callbacks=self.callbacks,
)
这里传给 model.fit 的输入是 NumPy 数组,因此需要事先单独处理:把各单元格序列不等长的 DataFrame 填充(padding)后转换成 NumPy 数组。
# Convert a DataFrame whose cells hold variable-length sequences into a dense
# float array of shape (n_rows, n_cols, max_seq_len), filling the tail of every
# sequence with NaN.
# NOTE(review): max_seq_len is taken from column 0 of `lengths` only, so this
# assumes no cell in any other column is longer than that — TODO confirm.
max_seq_len = int(np.max(lengths[:, 0]))
n_rows, n_cols = df.shape
# Start from an all-NaN buffer and copy each sequence into its prefix. Unlike
# np.pad(..., constant_values=np.nan), this also works for integer-valued
# sequences, which cannot hold NaN in their own dtype.
data = np.full((n_rows, n_cols, max_seq_len), np.nan)
for i in range(n_rows):
    for j in range(n_cols):
        seq = np.asarray(df.iloc[i, j], dtype=float)
        data[i, j, :len(seq)] = seq