# Build an ImageFolder dataset rooted at data/dogcat_2/ and display its
# class_to_idx mapping (class name -> integer label).
from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/dogcat_2/')
dataset.class_to_idx
输出:
{'.ipynb_checkpoints': 0, 'cat': 1, 'dog': 2}
可以看到,Jupyter 自动生成的隐藏目录 `.ipynb_checkpoints` 也被当成了一个类别,导致 cat、dog 的索引整体后移。解决办法是修改 ~/python3.6/site-packages/torchvision/datasets/folder.py 的源代码,
需要更改的是以下函数(更改前):
def find_classes(dir):
    """Return the class names found under ``dir`` and their indices.

    Every immediate subdirectory of ``dir`` is treated as one class.
    Nothing is filtered out here, so hidden subdirectories such as
    ``.ipynb_checkpoints`` are counted as classes too.

    Args:
        dir: Path of the dataset root directory.

    Returns:
        A ``(classes, class_to_idx)`` tuple: the subdirectory names in
        sorted order, and a dict mapping each name to its position in
        that sorted list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {}
    for index, name in enumerate(classes):
        class_to_idx[name] = index
    return classes, class_to_idx
更改后为:
def find_classes(dir):
    """Return the visible class names under ``dir`` and their indices.

    Every immediate subdirectory of ``dir`` is treated as one class,
    except subdirectories whose name starts with ``'.'`` (for example
    ``.ipynb_checkpoints``), which are skipped so that they are never
    assigned a class index.

    Args:
        dir: Path of the dataset root directory.

    Returns:
        A ``(classes, class_to_idx)`` tuple: the non-hidden subdirectory
        names in sorted order, and a dict mapping each name to its
        position in that sorted list.
    """
    classes = sorted(
        d for d in os.listdir(dir)
        # Cheap name test first, so hidden entries never reach isdir().
        if not d.startswith('.') and os.path.isdir(os.path.join(dir, d))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
# Rebuild the dataset from data/dogcat_2/ and display class_to_idx again,
# to check the mapping after find_classes has been changed.
from torchvision.datasets import ImageFolder
dataset = ImageFolder('data/dogcat_2/')
dataset.class_to_idx
输出:
{'cat': 0, 'dog': 1}