一.数据集的加载
数据集有官方数据集和自己的数据集两种,对于不同的数据集加载方式有差别,大致如下:
1.对于官方数据集,即可以通过torchvision.datasets后面加点的方式获取数据集名称的这种数据集,如CIFAR10:
torchvision.datasets.CIFAR10(root='./data', train=True,
download=False, transform=transform)
它的训练集和验证集加载可以通过torchvision.datasets.CIFAR10(root=’xxx’,.......) + torch.utils.data.DataLoader()的方式进行加载,代码如下:
# 训练集的设置及其DataLoader
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
download=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
shuffle=True, num_workers=0)
# 验证集的设置及其DataLoader
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
download=False, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,
shuffle=False, num_workers=0)
2.对于自己通过建立多个文件夹、每个文件夹放一类图片且文件夹的名称既是该类的名称来得到的自己的数据集,使用torchvision.datasets后面加点的方式是无法访问得到的,这时可以选择torchvision.datasets.ImageFolder来加载数据集,即通过添加数据集文件夹所在绝对路径来进行数据集的加载的。下面代码的路径是我的数据集文件夹train所在的绝对路径,后面的transform是提前设置好的transform。
train_dataset = torchvision.datasets.ImageFolder(root=”F:\data_set\flower_data\train”),
transform=transform)
ImageFolder就是一个通用的data loader,而这个data loader加载数据集的方式就是通过路径,所以ImageFolder代替了torchvision.datasets.CIFAR10这类指定数据集的操作。需要注意的是,给ImageFolder后面的第一个参数root指定根路径时,我的编译器需要写的是图片所在文件夹的绝对路径。完整代码:
# 训练集的设置及其DataLoader
train_dataset = torchvision.datasets.ImageFolder(root=”F:\data_set\flower_data\train”),
transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=num)
# 验证集的设置及其DataLoader
validate_dataset = datasets.ImageFolder(root=”F:\data_set\flower_data\val”),
transform=val_transform)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
batch_size=4, shuffle=False,
num_workers=nw)
二.关于类别json文件的生成
我的训练集文件夹分类如下图所示:
类别信息生成json文件的代码如下:
# flower_list = {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx #将各类别名称及其文件夹顺序索引做成一个字典
# cla_dict = {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
cla_dict = dict((val, key) for key, val in flower_list.items()) # 交换key和value的位置
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file: # 写json文件
json_file.write(json_str)
json_str = json.dumps(cla_dict, indent=4)将字典写成json字符串以便存入json文件,所存的json文件如下:
这样类别的内容显示更见直观,参数indent表示每一个字典item前面的空格数量,即"0": "daisy"前面空了4个空格,后面的也是如此。