自定义数据集类读取本地图片,并能够用来训练模型

我用CVAE-GAN将cifar100的图片进行了部分替换,然后把这些图片全部保存到本地的文件夹中,然后自定义类来重新构建数据集

'/test_change_64/'是我保存图片的文件夹的相对路径

def getNewDataloader(BATCH_SIZE = 128):
    all_imgs_path = os.listdir('./test_change_64/')

    species = ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm']

    all_labels = []

    for i in range(len(all_imgs_path)):
        all_imgs_path[i] = './test_change_64/' + all_imgs_path[i]

    #对所有图片路径进行迭代
    for img in all_imgs_path:
        # 区分出每个img,应该属于什么类别
        name_label = img.split('/')[2].split(' ')[0]
        for i, c in enumerate(species):
            if i == int(name_label):
                all_labels.append(i)

    # 对数据进行转换处理
    mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # weather_dataset = Mydatasetpro(all_imgs_path, all_labels, transform)
    # wheather_datalodaer = data.DataLoader(
    #                             weather_dataset,
    #                             batch_size=BATCH_SIZE,
    #                             shuffle=True
    # )
    # imgs_batch, labels_batch = next(iter(wheather_datalodaer))
    # plt.figure(figsize=(12, 8))
    # for i, (img, label) in enumerate(zip(imgs_batch, labels_batch)):
    #     img = img.permute(1, 2, 0).numpy()
    #     plt.subplot(2, 3, i+1)
    #     plt.title(id_to_species.get(label.item()))
    # plt.show()

    #划分测试集和训练集
    all_imgs_path = np.array(all_imgs_path)
    all_labels = np.array(all_labels)

    train_imgs = all_imgs_path
    train_labels = all_labels

    train_ds = Mydatasetpro(train_imgs, train_labels, transform) #TrainSet TensorData
    train_loader = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)#TrainSet Labels

    return train_loader

如果需要测试集,则只需要将训练集划分

#划分测试集和训练集
index = np.random.permutation(len(all_imgs_path))

all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]

#80% as train
s = int(len(all_imgs_path)*0.8)
print(s)

train_imgs = all_imgs_path[:s]
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_imgs_path[s:]

train_ds = Mydatasetpro(train_imgs, train_labels, transform) #TrainSet TensorData
train_loader = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False)#TrainSet Labels

test_ds = Mydatasetpro(test_imgs, test_labels, transform) #TestSet TensorData
test_loader = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)#TestSet

return train_loader,test_loader

但这时候是需要先打乱数据集再进行划分,同时使用DataLoader的shuffle就不用True,改成False,不然会报错的 !

使用自定义数据集训练模型:

train_new_loader = getNewDataloader(batch_size)

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值