from torch.utils.data import DataLoader, TensorDataset,WeightedRandomSampler
首先我们确保引入了TensorDataset
其次,我们的数据已经被处理成tensor形状,如果是Dataframe需要先转array再转tensor,我这里使用的是入侵检测数据集UNSWNB15,简单的二维数据集,读取的部分就不展示了
X = torch.tensor(np.array(new_train_df),dtype = torch.float32).to(device)
y = torch.tensor(np.array(y_train_re),dtype = torch.long).to(device)
代码中的new_train_df和y_train_re都是Dataframe格式,通过上面的代码处理后,能够进行TensorDataset的封装,如下
X_train,X_test,y_train,y_test = train_test_split(X,y,train_size=0.8,random_state=42)
train_dataset = TensorDataset(X_train,y_train)
test_dataset = TensorDataset(X_test,y_test)
这里可能在train_dataset之后需要输出里面的数据形状或者label,代码如下:
data,label = train_dataset[0]
print(len(train_dataset))
print(data.size())
data,label = test_dataset[0]
print(len(test_dataset))
print(data.size())
输出为:
206138
torch.Size([42])
51535
torch.Size([42])
如果需要查看标签直接print(label)就可以