# PyTorch极简入门教程(十八)——自定义输入Dataset类

## 自定义输入Dataset类

## 导入必要的模块
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader

import torchvision
import matplotlib.pyplot as plt

import glob
## 导入数据的图片和标签
import os

# Collect every .jpg in the dataset directory.
# NOTE(review): hard-coded Windows path from the original author; adjust locally.
all_imgs_path = glob.glob(r"G:\Proiect\Pytorch1.6\pytorch进阶\dataset2\*.jpg")
print("all_imgs_path[:5]:\t", all_imgs_path[:5])

# Class names; a name's position in this list is its integer label.
species = ["cloudy", "rain", "shine", "sunrise"]
species_to_idx = {c: i for i, c in enumerate(species)}
idx_to_species = {i: c for c, i in species_to_idx.items()}
print("idx_to_species:\t", idx_to_species)


def label_from_path(path):
    """Return the index of the first class name found in the file name of *path*.

    Only the base name is inspected, so directory names (e.g. a folder whose
    name contains "rain") cannot influence the label. Exactly one label is
    returned per path, which keeps ``all_labels`` aligned index-for-index
    with ``all_imgs_path``.

    Raises:
        ValueError: if none of ``species`` occurs in the file name.
    """
    name = os.path.basename(path)
    for idx, cls in enumerate(species):
        if cls in name:
            return idx
    raise ValueError(f"cannot infer a class label from file name: {path!r}")


# Extract the label of every image: one entry per path, in the same order.
all_labels = [label_from_path(p) for p in all_imgs_path]
## 定义transform的形式
# Preprocessing pipeline: resize every image to exactly 96x96 (a (h, w) size
# does not preserve aspect ratio), then convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(96, 96)),
        transforms.ToTensor(),
    ]
)
## 定义数据类
# A map-style Dataset must implement __init__, __getitem__ and __len__.
class Mydataset(data.Dataset):
    """Map-style dataset pairing image file paths with labels.

    Args:
        img_paths: indexable sequence of image file paths.
        labels: indexable sequence of labels; ``labels[i]`` belongs to
            ``img_paths[i]``.
        transform: callable applied to each loaded PIL image (e.g. a
            ``torchvision.transforms.Compose``).
    """

    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform

    def __getitem__(self, index):
        """Load the image at *index*, apply the transform, return ``(image, label)``."""
        path = self.imgs[index]
        label = self.labels[index]
        # Open inside ``with`` so the file handle is released after reading,
        # and force RGB so grayscale/RGBA/CMYK files still yield 3-channel
        # tensors that the DataLoader can stack into one batch. ``convert``
        # loads the pixel data, so the result stays usable after the file closes.
        with Image.open(path) as pil_img:
            rgb_img = pil_img.convert("RGB")
        sample = self.transforms(rgb_img)
        return sample, label

    def __len__(self):
        """Return the number of image paths."""
        return len(self.imgs)
## 对象初始化
# Batch size shared by the preview loader below.
BATCH_SIZE = 16

# Dataset over every (image path, label) pair collected above.
weather_dataset = Mydataset(all_imgs_path, all_labels, transform)
print("type(weather_dataset):\t", type(weather_dataset))

# Shuffled loader over the whole dataset; here it only supplies one preview batch.
weather_dl = DataLoader(weather_dataset, batch_size=BATCH_SIZE, shuffle=True)
imgs_batch, labels_batch = next(iter(weather_dl))
## 图片和标签展示
# Show (up to) the last six images of the preview batch in a 2x3 grid,
# each titled with its class name.
plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs_batch[-6:], labels_batch[-6:])):
    # Move the first axis to the end: (C, H, W) from ToTensor -> (H, W, C),
    # the channels-last layout imshow expects.
    img = img.permute(1, 2, 0).numpy()
    plt.subplot(2, 3, i + 1)
    # Title each subplot with the class name of its integer label.
    plt.title(idx_to_species.get(label.item()))
    plt.imshow(img)
# Show once, after every subplot is drawn. Calling show() inside the loop
# flushes the figure after the first image, so the later subplots end up on
# new, default-sized figures instead of the 12x8 grid.
plt.show()
## 划分训练集和测试集
# Paths and labels are shuffled with one shared permutation, so they must
# correspond one-to-one; fail loudly instead of silently mispairing them.
if len(all_imgs_path) != len(all_labels):
    raise ValueError(
        f"{len(all_imgs_path)} image paths but {len(all_labels)} labels; "
        "they must be aligned one-to-one before splitting"
    )

# A random permutation with one entry per image, applied to both arrays.
index = np.random.permutation(len(all_imgs_path))
all_imgs_path = np.array(all_imgs_path)[index]
print("all_imgs_path[:5]:\t", all_imgs_path[:5])
all_labels = np.array(all_labels)[index]
print("all_labels[:5]:\t", all_labels[:5])

# Training-set size: the first 80% of the shuffled data; the rest is test.
s = int(len(all_imgs_path) * 0.8)
print("s:\t", s)

train_img = all_imgs_path[:s]
train_labels = all_labels[:s]
test_img = all_imgs_path[s:]
test_labels = all_labels[s:]

train_ds = Mydataset(train_img, train_labels, transform)
test_ds = Mydataset(test_img, test_labels, transform)

# Reuse BATCH_SIZE so all loaders in the script stay in sync.
train_dl = data.DataLoader(train_ds,
                           batch_size=BATCH_SIZE,
                           shuffle=True
                           )
# The test loader keeps the (already shuffled-once) order; no per-epoch shuffle.
test_dl = data.DataLoader(test_ds,
                          batch_size=BATCH_SIZE,
                          )
## 定义自己的Dataset
class New_dataset(data.Dataset):
    """Wrap another dataset and return its images channels-last.

    Each item of the wrapped dataset must unpack as ``(image_tensor, label)``.
    The image's first axis is moved to the end (e.g. (C, H, W) from ToTensor
    becomes (H, W, C)); the label passes through unchanged.
    """

    def __init__(self, some_dataset):
        self.ds = some_dataset

    def __len__(self):
        # Same length as the wrapped dataset.
        return len(self.ds)

    def __getitem__(self, index):
        sample, target = self.ds[index]
        return sample.permute(1, 2, 0), target

# Wrap the training split so samples come back channels-last, then inspect one.
train_new_dataset = New_dataset(train_ds)
img, label = train_new_dataset[2]
# ``shape`` is an attribute (a torch.Size), not a method, so label the output
# without call parentheses; ``img.shape()`` would raise TypeError.
print("img.shape:", img.shape)
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值