1 get_dataset
2 list_dataset.py/ListDataset
from torch.utils.data import Dataset
class ListDataset(Dataset):
def __init__(self, data):
"""
data: 必须是一个 list
"""
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
__init__(self, data)
:构造函数接收一个参数 data
,这个参数必须是一个列表。在这个构造函数中,传入的列表被保存在实例变量 self.data
中。-
__getitem__(self, index)
:这是一个特殊方法,允许类的实例像列表一样通过索引访问。该方法返回 self.data
中指定索引 index
的元素。这是 PyTorch Dataset
接口的一部分,使得每个元素可以通过索引直接访问,非常适合于训练过程中获取单个训练样本。
-
__len__(self)
:这也是一个特殊方法,它返回 self.data
列表的长度,即数据集中的元素总数。这个方法提供了数据集的大小,是 DataLoader
在数据加载过程中进行批处理、采样等操作时需要用到的信息。
3 generate_data
pytorch笔记:Dataloader_tensor如何装进dataloader里-CSDN博客