import os
import cv2
import torchvision
def _export_images_by_class(dataset, output_root='./results'):
    """Write every image of *dataset* to ``<output_root>/<class>/img_<n>.jpg``.

    Args:
        dataset: Object exposing ``classes`` (list of class names), ``data``
            (sequence of image arrays) and ``targets`` (sequence of integer
            class indices), as ``torchvision.datasets.CIFAR10`` does.
        output_root: Directory under which one sub-directory per class name
            is created.

    Raises:
        OSError: If OpenCV reports that an image could not be written.
    """
    # Per-class running index, used to number the files of each class from 0.
    next_index = dict.fromkeys(dataset.classes, 0)

    for pixels, target in zip(dataset.data, dataset.targets):
        label = dataset.classes[target]

        class_dir = '%s/%s' % (output_root, label)
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(class_dir, exist_ok=True)

        fp = os.path.join(class_dir, 'img_%s.jpg' % next_index[label])
        next_index[label] += 1
        print(fp)

        # The dataset stores pixels in RGB order, while cv2.imwrite expects
        # BGR; without this conversion red and blue are swapped on disk.
        bgr_pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)

        # cv2.imwrite signals failure through its return value, not an
        # exception, so check it explicitly.
        if not cv2.imwrite(fp, bgr_pixels):
            raise OSError('cv2.imwrite failed to write %s' % fp)


if __name__ == '__main__':
    trainset = torchvision.datasets.CIFAR10(
        root='./datasets/', train=True, download=True)
    _export_images_by_class(trainset)