可以用来对 tensor 进行打包,就好像 python 中的 zip 功能。
是 PyTorch 提供的一个用于将张量(tensor)组合成数据集的工具类。它的作用是将多个张量作为输入,构建一个能够以样本对的形式提供数据的数据集。
通常情况下,我们会使用 TensorDataset
将输入数据(如特征张量)和标签数据(如目标张量)组合成一个数据集。然后,我们可以将这个数据集传递给 PyTorch 的数据加载器(如 DataLoader
)用于训练模型。
使用 TensorDataset
的好处之一是它可以方便地与 PyTorch 的数据加载器一起使用,从而可以利用 PyTorch 提供的各种功能来对数据进行批处理、随机化、并行加载等操作。
例如:
data_set = TensorDataset(
torch.tensor(user, dtype=torch.int64),
torch.tensor(item, dtype=torch.int64),
knowledge_emb,
torch.tensor(score, dtype=torch.float32)
)
当你想要调用其中的一部分数据时,可通过索引来访问 data_set
中的每个样本,然后使用类似列表索引的方式来获取对应的数据。
例如,如果你想获取data_set中第一个样本的用户、题目和得分,可以这样做:
sample = data_set[0]
user = sample[0]
item = sample[1]
knowledge_emb = sample[2]
score = sample[3]
这里的 user
、item
、knowledge_emb
和 score
分别表示第一个样本的用户、题目、知识点编码和得分。data_set[0]中存储了user/item/knowledge_emb/score的首位。