How to Build CIFAR10 Datasets at Other Image Sizes (with Source Code)

For how to build a CIFAR10-style dataset in the first place, see my previous post:

Hands-on Guide to Building Your Own CIFAR Dataset (with Project Source Code)

The code in that post can only generate 32*32 CIFAR data. In my testing, simply changing the shape in that code does not correctly produce datasets of other sizes, so this post walks through every change that is actually required.

  1. First, on line 15 of demo.py, change shape=32 to your desired input size (e.g. 224):

    if __name__ == '__main__':
      data, label, lst = read_data(file_list, data_path, shape=224)
      pickled(save_path, data, label, lst, bin_num = 5)  # bin_num is the number of batch files to generate
    
    
  2. Modify the hyperparameters on lines 10-12 of load_data.py:

    DATA_LEN = 150528    # data length = channels * width * height; 150528 = 3*224*224
    CHANNEL_LEN = 50176  # channel length = data length / channels; 50176 = 150528/3
    SHAPE = 224          # image size
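The three constants above are not independent: both lengths follow directly from the image size. A small helper (hypothetical, not part of load_data.py) keeps them consistent when you retarget the dataset:

```python
# Helper (hypothetical, not in load_data.py): derive all three
# constants from the target image size so they never drift apart.
CHANNELS = 3

def cifar_constants(shape):
    """Return (DATA_LEN, CHANNEL_LEN, SHAPE) for a square image."""
    channel_len = shape * shape        # pixels per channel plane
    data_len = CHANNELS * channel_len  # flattened row length per image
    return data_len, channel_len, shape

print(cifar_constants(224))  # (150528, 50176, 224)
print(cifar_constants(32))   # (3072, 1024, 32)
```

Calling it with 32 reproduces the stock CIFAR10 values (3072 and 1024), which is a quick way to confirm the formula.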
    
  3. On line 11 of edit_mate.py, change 'num_vis': 3072 to 150528 (the data length):

    dictCow = {'num_cases_per_batch': 3139,  # number of samples per batch
           'label_names': ['1','10','2','3','4','5','6','7','8','9'],  # class names: fill in the label_names from your class index table (object_list.txt)
           'num_vis': 150528}  # set this to your data length = channels * width * height
    
  4. Modify line 87 of torchvision's cifar.py.

    cifar.py lives inside your Python environment, under
    lib -> python3.6 -> site-packages -> torchvision -> datasets -> cifar.py
    (here the environment is named ViT; substitute your own Python location).

    Change

     self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)  # change the 32s here to your image size, e.g. 224

    to

     self.data = np.vstack(self.data).reshape(-1, 3, 224, 224)  # CIFAR uses this to compute the number of iterations
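If this patch does not match your actual image size, the reshape fails with a ValueError at load time. A quick sanity check with synthetic data (sizes are illustrative) mirrors the reshape/transpose that torchvision performs:

```python
import numpy as np

# Sanity check with fake data: a flat CIFAR-style row of length
# 3*224*224 must unpack into a (3, 224, 224) CHW image, then
# transpose to HWC exactly as torchvision does.
SHAPE = 224
n_images = 2
n_values = n_images * 3 * SHAPE * SHAPE
flat = (np.arange(n_values) % 256).astype(np.uint8).reshape(n_images, -1)

data = np.vstack([flat]).reshape(-1, 3, SHAPE, SHAPE)  # NCHW
data = data.transpose((0, 2, 3, 1))                    # NHWC, as torchvision does
print(data.shape)  # (2, 224, 224, 3)
```

The red channel of the first image should match the first SHAPE*SHAPE values of the first flat row, confirming the channel planes are stored contiguously.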
    
  5. Re-run the pipeline to rebuild the CIFAR10 dataset.

For reference, here is the torchvision.datasets.CIFAR10 source code:

```
import torch.utils.data as data
from PIL import Image
import os
import os.path
import numpy as np
import pickle

from .utils import download_url, check_integrity


class CIFAR10(data.Dataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be downloaded to if
            download is set to True.
        train (bool, optional): If True, creates dataset from training set,
            otherwise creates from test set.
        transform (callable, optional): A function/transform that takes in an
            PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f03828d0aae7e51cd9d'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if 'meta' in file_name:
                    data_dict = pickle.load(f, encoding='latin1')
                    self.classes = data_dict['label_names']
                else:
                    data_dict = pickle.load(f, encoding='latin1')
                    self.data.append(data_dict['data'])
                    self.targets.extend(data_dict['labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url(self.url, self.root, self.filename, self.tgz_md5)

        # extract file
        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)
        print('Done!')


class CIFAR100(CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    This is a subclass of the `CIFAR10` Dataset.
    """
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR100, self).__init__(root, train=train,
                                       transform=transform,
                                       target_transform=target_transform,
                                       download=download)
```

This code defines the CIFAR10 and CIFAR100 dataset classes, two standard datasets for image classification. Each dataset has a training set and a test set, and every image carries a label indicating its class.

Both classes inherit from torch.utils.data.Dataset and implement the __getitem__ and __len__ methods. They also provide methods for downloading the dataset and checking its integrity.

During initialization, the dataset loads data and labels from pickle files and stores them in self.data and self.targets. __getitem__ returns an (image, label) tuple, and __len__ returns the number of images in the dataset.
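To verify that your rebuilt batches match the pickle layout this loader expects, you can round-trip a tiny synthetic batch. The 224 size, file name, and sample count below are illustrative, not taken from the project:

```python
import os
import pickle
import tempfile

import numpy as np

# Round-trip a tiny synthetic batch in the CIFAR-10 pickle layout the
# loader above reads: a 'data' array of flat rows (3*SHAPE*SHAPE each)
# and a 'labels' list. Size, file name and counts are illustrative.
SHAPE = 224
batch = {
    'data': np.zeros((4, 3 * SHAPE * SHAPE), dtype=np.uint8),
    'labels': [0, 1, 2, 3],
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'data_batch_1')
    with open(path, 'wb') as f:
        pickle.dump(batch, f)
    with open(path, 'rb') as f:
        loaded = pickle.load(f, encoding='latin1')

# Same reshape the patched cifar.py performs at load time:
images = loaded['data'].reshape(-1, 3, SHAPE, SHAPE).transpose((0, 2, 3, 1))
print(images.shape)  # (4, 224, 224, 3)
```

If the reshape here raises a ValueError, the batch rows do not have the length that steps 1-4 assume, and the dataset will fail in torchvision the same way.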
