我用CVAE-GAN将cifar100的图片进行了部分替换,然后把这些图片全部保存到本地的文件夹中,然后自定义类来重新构建数据集
'./test_change_64/' 是我保存图片的文件夹的相对路径(相对于当前工作目录),文件名以"类别编号 + 空格"开头,例如 "12 xxx.png" 表示第 12 类
def getNewDataloader(BATCH_SIZE=128, img_dir='./test_change_64/'):
    """Build a shuffled training DataLoader from images saved on disk.

    Each file in ``img_dir`` must be named ``"<label> <anything>"``: the text
    before the first space is parsed as the integer class index of the image.

    Args:
        BATCH_SIZE: Number of samples per batch.
        img_dir: Directory that holds the saved images.

    Returns:
        A ``data.DataLoader`` (shuffle enabled) wrapping a ``Mydatasetpro``
        built from every file found in ``img_dir``.

    Raises:
        ValueError: If a file name does not start with an integer, or the
            integer is not a valid index into the class-name list.
    """
    # Class names; a label is an index into this list.
    species = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
        'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
        'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
        'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
        'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
        'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
        'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
        'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
    ]
    num_classes = len(species)

    # Sorted so the path/label order is deterministic across runs and platforms
    # (os.listdir returns entries in arbitrary order).
    all_imgs_path = [os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir))]

    # Derive one label per image from its file name. Every path must yield
    # exactly one label, otherwise images and labels would go out of step.
    all_labels = []
    for img in all_imgs_path:
        label = int(os.path.basename(img).split(' ')[0])
        if not 0 <= label < num_classes:
            raise ValueError(
                'label %d parsed from %r is outside [0, %d)' % (label, img, num_classes)
            )
        all_labels.append(label)

    # Per-channel normalisation statistics (presumably the CIFAR-100 training
    # set mean/std -- confirm against how the images were generated).
    mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # No split here: every image goes into the training set.
    train_imgs = np.array(all_imgs_path)
    train_labels = np.array(all_labels)
    train_ds = Mydatasetpro(train_imgs, train_labels, transform)
    train_loader = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    return train_loader
如果还需要测试集,只需把上面函数中"划分测试集和训练集"之后的部分替换为下面的代码,将全部数据按 8:2 划分为训练集和测试集:
#划分测试集和训练集
index = np.random.permutation(len(all_imgs_path))
all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]
#80% as train
s = int(len(all_imgs_path)*0.8)
print(s)
train_imgs = all_imgs_path[:s]
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_imgs_path[s:]
train_ds = Mydatasetpro(train_imgs, train_labels, transform) #TrainSet TensorData
train_loader = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False)#TrainSet Labels
test_ds = Mydatasetpro(test_imgs, test_labels, transform) #TestSet TensorData
test_loader = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)#TestSet
return train_loader,test_loader
注意:这种情况下需要先打乱数据集再进行划分;由于数据已经手动打乱,DataLoader 的 shuffle 设为 False 即可(作者在实验中将其设为 True 时遇到了报错)。
使用自定义数据集训练模型:
# Build the training loader from the custom on-disk dataset.
train_new_loader = getNewDataloader(BATCH_SIZE=batch_size)