pytorch中通过继承Dataset类来自定义数据集

在将数据输入网络前,一般需要通过DataSet和DataLoader两个类将数据进行整理,但是默认的DataSet类对于无标签数据的支持并不友好,通过查资料发现,可以自己重写DataSet来解决这个问题

class CustomDataSet(Dataset):
    def __init__(self, path): #这一步主要是来读取数据,也可在这一步将x,y划分
        self.data = pd.read_csv(path, index_col=0)
        self.data = np.array(self.data)

    def __getitem__(self, index): #这一步主要是返回x,y,若数据中无y则删去和y相关的代码
        x_data = self.data[index][:-1]
        y_data = self.data[index][-1]
        x_data = torch.tensor(x_data,dtype=torch.float32).unsqueeze(0)
        y_data = torch.tensor(y_data)
        return x_data,y_data

    def __len__(self):  #返回样本个数,不是特征个数,这个返回结果直接影响index的值
        return self.data.shape[0]

通过上述的改写,我们在调用CustomDataSet时只需填写path就可以了,然后将结果放入DataLoader中即可

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch 自定义数据集可以通过继承 `torch.utils.data.Dataset` 类来实现。这个类需要实现两个方法:`__len__` 和 `__getitem__`。 `__len__` 方法返回数据集的长度,即样本数量。`__getitem__` 方法返回数据集一个索引对应的样本。 下面是一个简单的例子,假设我们有一个文件夹 `data`,里面包含若干张图片和对应的标签,我们要把这个数据集PyTorch 加载起来: ```python import os from PIL import Image import torch.utils.data as data class CustomDataset(data.Dataset): def __init__(self, root_dir): self.root_dir = root_dir self.img_list = os.listdir(root_dir) def __len__(self): return len(self.img_list) def __getitem__(self, index): img_path = os.path.join(self.root_dir, self.img_list[index]) img = Image.open(img_path).convert('RGB') label = int(self.img_list[index].split('_')[0]) return img, label ``` 在上面的例子,我们定义了一个 `CustomDataset` 类,它有一个构造函数 `__init__`,接收一个参数 `root_dir` 表示数据集所在的文件夹路径。`__init__` 方法初始化了 `img_list` 属性,里面保存了所有图片文件名。 `__len__` 方法返回了 `img_list` 的长度,即数据集样本的数量。 `__getitem__` 方法接收一个索引 `index`,返回了数据集第 `index` 个样本的图片和标签。具体地,它首先获取了图片文件的路径,然后用 `PIL` 库打开图片并转换成 RGB 模式。最后,它从文件名解析出标签信息,并把图片和标签一起返回。 有了这个自定义数据集类,我们就可以用 PyTorch 的 `DataLoader` 类来加载数据集了。例如: ```python import torch.utils.data as data dataset = CustomDataset('data') dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True) ``` 在上面的例子,我们创建了一个 `CustomDataset` 对象 `dataset`,然后用 `DataLoader` 类来初始化 `dataloader` 对象。`DataLoader` 的第一个参数是数据集对象,第二个参数是批量大小,第三个参数是是否打乱数据集顺序。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值