DataLoader 多模态输入方法
多模态研究中,经常需要多种特征拼接输入网络,例如cnn的输出拼接低维特征后连接线性层进行输出、简单总结DataLoader使用办法:
import torch.utils.data.dataset as Dataset
class subDataset(Dataset.Dataset):
def __init__(self,Feature_1,Feature_2,Label):
self.Feature_1 = Feature_1
self.Feature_2 = Feature_2
self.Label = Label
def __len__(self):
return len(self.Label)
def __getitem__(self,index):
Feature_1 = torch.Tensor(self.Feature_1[index])
Feature_2 = torch.Tensor(self.Feature_2[index])
Label = torch.Tensor(self.Label[index])
return Feature_1,Feature_2,Label
train_dataset = subDataset(Feature_1,Feature_2,Label)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)