Data Processing (1): DataLoader and Datasets

Data Processing

Data Loading

Different ways of storing the data (using the mini-ImageNet dataset as an example) call for different dataloaders.

(1) train.pickle, val.pickle, and test.pickle files
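Each pickle file is expected to be a dict with a 'data' key (an array of images) and a 'labels' key (one integer label per image); that is the structure the loading code below assumes. A quick way to inspect a file and confirm this (a minimal sketch, with the filename as a placeholder):

import pickle

with open('miniImageNet_category_split_test.pickle', 'rb') as f:
    d = pickle.load(f, encoding='latin1')  # latin1 handles Python-2-era pickles
print(d.keys())          # expect at least 'data' and 'labels'
print(len(d['labels']))  # number of images in this split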

Method 1:
1. First create a `datasets` folder and place miniImageNet.py in it:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import os.path as osp
import pickle
import numpy as np


def load_data(file):
    # Pickles created under Python 2 may need encoding='latin1' under Python 3.
    try:
        with open(file, 'rb') as fo:
            data = pickle.load(fo)
    except UnicodeDecodeError:
        with open(file, 'rb') as fo:
            data = pickle.load(fo, encoding='latin1')
    return data


def buildLabelIndex(labels):
    label2inds = {}
    for idx, label in enumerate(labels):
        if label not in label2inds:
            label2inds[label] = []
        label2inds[label].append(idx)

    return label2inds


class miniImageNet(object):
    """
    Dataset statistics:
    # 64 * 600 (train) + 16 * 600 (val) + 20 * 600 (test)
    """
    dataset_dir = 'path/your data/miniimagenet/'  # root directory where the data is stored

    def __init__(self, **kwargs):
        super(miniImageNet, self).__init__()
        self.train_dir = os.path.join(self.dataset_dir, 'miniImageNet_category_split_train_phase_train.pickle')
        self.val_dir = os.path.join(self.dataset_dir, 'miniImageNet_category_split_val.pickle')
        self.test_dir = os.path.join(self.dataset_dir, 'miniImageNet_category_split_test.pickle')
        self._check_before_run()

        self.train, self.train_labels2inds, self.train_labelIds = self._process_dir(self.train_dir)
        self.val, self.val_labels2inds, self.val_labelIds = self._process_dir(self.val_dir)
        self.test, self.test_labels2inds, self.test_labelIds = self._process_dir(self.test_dir)

        self.num_train_cats = len(self.train_labelIds)
        num_total_cats = len(self.train_labelIds) + len(self.val_labelIds) + len(self.test_labelIds)
        num_total_imgs = len(self.train + self.val + self.test)

        print("=> MiniImageNet")
        print("Dataset statistics:")
        print("  ------------------------------")
        print("  subset   | # cats | # images")
        print("  ------------------------------")
        print("  train    | {:5d} | {:8d}".format(len(self.train_labelIds), len(self.train))) #train类别数量 train总图片数量
        print("  val      | {:5d} | {:8d}".format(len(self.val_labelIds),   len(self.val)))
        print("  test     | {:5d} | {:8d}".format(len(self.test_labelIds),  len(self.test)))
        print("  ------------------------------")
        print("  total    | {:5d} | {:8d}".format(num_total_cats, num_total_imgs))
        print("  ------------------------------")

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.val_dir):
            raise RuntimeError("'{}' is not available".format(self.val_dir))
        if not osp.exists(self.test_dir):
            raise RuntimeError("'{}' is not available".format(self.test_dir))

    def _get_pair(self, data, labels):
        data = np.array(data)
        assert (data.shape[0] == len(labels))
        data_pair = []
        for i in range(data.shape[0]):
            data_pair.append((data[i], labels[i]))
        return data_pair

    def _process_dir(self, file_path):
        dataset = load_data(file_path)
        data = dataset['data']
        labels = dataset['labels']
        data_pair = self._get_pair(data, labels)
        labels2inds = buildLabelIndex(labels)
        labelIds = sorted(labels2inds.keys())
        return data_pair, labels2inds, labelIds

if __name__ == '__main__':
    miniImageNet()

You also need an `__init__.py` file in the `datasets` folder that exposes the dataset classes; its contents:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .miniImageNet import miniImageNet
from .tieredImageNet import tieredImageNet  # analogous class for tieredImageNet; remove if unused


__imgfewshot_factory = {
        'miniImageNet': miniImageNet,
        'tieredImageNet': tieredImageNet,
}

def get_names():
    return list(__imgfewshot_factory.keys()) 

def init_imgfewshot_dataset(name, **kwargs):
    if name not in list(__imgfewshot_factory.keys()):
        raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, list(__imgfewshot_factory.keys())))
    return __imgfewshot_factory[name](**kwargs)
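With the factory in place, a dataset can be instantiated by name. A minimal usage sketch (assuming the `datasets` package created above is on the import path):

from datasets import init_imgfewshot_dataset

dataset = init_imgfewshot_dataset('miniImageNet')
print(dataset.num_train_cats)   # 64 train classes
img, label = dataset.train[0]   # each entry is an (image array, label) pair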

2. Place the data (the three pickle files) under the 'path/your data/miniimagenet/' directory.
3. Create a new `dataset_loader` folder. Contents of train_loader.py (which implements the few-shot episode sampling):

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os.path as osp
import random

from PIL import Image

import torch
from torch.utils.data import Dataset


def read_image(img_path):
    """Keep reading image until succeed.
    This can avoid IOError incurred by heavy IO process."""
    got_img = False
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while not got_img:
        try:
            img = Image.open(img_path).convert('RGB')
            got_img = True
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
            pass
    return img


class FewShotDataset_train(Dataset):
    """Few shot epoish Dataset

    Returns a task (Xtrain, Ytrain, Xtest, Ytest, Ycls) to classify'
        Xtrain: [nKnovel*nExpemplars, c, h, w].
        Ytrain: [nKnovel*nExpemplars].
        Xtest:  [nTestNovel, c, h, w].
        Ytest:  [nTestNovel].
        Ycls: [nTestNovel].
    """

    def __init__(self,
                 dataset,          # list of (image, category) pairs.
                 labels2inds,      # dict mapping category -> list of sample indices.
                 labelIds,         # sorted list of category ids.
                 nKnovel=5,        # number of novel categories per episode.
                 nExemplars=1,     # number of support examples per novel category.
                 nTestNovel=6*5,   # total number of query examples over all novel categories.
                 epoch_size=2000,  # number of episodes (tasks) per epoch.
                 transform=None,
                 load=False,
                 **kwargs
                 ):
        
        self.dataset = dataset
        self.labels2inds = labels2inds
        self.labelIds = labelIds
        self.nKnovel = nKnovel
        self.transform = transform

        self.nExemplars = nExemplars
        self.nTestNovel = nTestNovel
        self.epoch_size = epoch_size
        self.load = load

    def __len__(self):
        return self.epoch_size

    def _sample_episode(self):
        """sampels a training epoish indexs.
        Returns:
            Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label)
            Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples. (sample_index, label)
        """

        Knovel = random.sample(self.labelIds, self.nKnovel)
        nKnovel = len(Knovel)
        assert((self.nTestNovel % nKnovel) == 0)
        nEvalExamplesPerClass = int(self.nTestNovel / nKnovel)

        Tnovel = []
        Exemplars = []
        for Knovel_idx in range(len(Knovel)):
            ids = (nEvalExamplesPerClass + self.nExemplars)
            img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids) 

            imgs_tnovel = img_ids[:nEvalExamplesPerClass]
            imgs_exemplars = img_ids[nEvalExamplesPerClass:]

            Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel]
            Exemplars += [(img_id, Knovel_idx) for img_id in imgs_exemplars]
        assert(len(Tnovel) == self.nTestNovel)
        assert(len(Exemplars) == nKnovel * self.nExemplars)
        random.shuffle(Exemplars)
        random.shuffle(Tnovel)

        return Tnovel, Exemplars

    def _creatExamplesTensorData(self, examples):
        """
        Creates the image/label tensor data for the given examples.

        Args:
            examples: a list of 2-element tuples. (sample_index, label).

        Returns:
            images: a tensor [nExemplars, c, h, w]
            labels: a tensor [nExemplars]
            cls: a tensor [nExemplars]
        """

        images = []
        labels = []
        cls = []
        for (img_idx, label) in examples:
            img, ids = self.dataset[img_idx]
            if self.load:
                img = Image.fromarray(img)
            else:
                img = read_image(img)
            if self.transform is not None:
                img = self.transform(img)
            images.append(img)
            labels.append(label)
            cls.append(ids)
        images = torch.stack(images, dim=0)
        labels = torch.LongTensor(labels)
        cls = torch.LongTensor(cls)
        return images, labels, cls


    def __getitem__(self, index):
        Tnovel, Exemplars = self._sample_episode()
        Xt, Yt, Ytc = self._creatExamplesTensorData(Exemplars)
        Xe, Ye, Yec = self._creatExamplesTensorData(Tnovel)
        return Xt, Yt, Xe, Ye, Yec  # the set of returned values can be adjusted as needed
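Note that `_creatExamplesTensorData` calls `torch.stack` on the transformed images, so the `transform` passed in must end by converting each PIL image to a tensor. A typical training transform (a sketch built from standard torchvision transforms, using the mini-ImageNet normalization statistics that appear in Method 2 below):

import torchvision.transforms as T

transform_train = T.Compose([
    T.RandomCrop(84, padding=8),
    T.RandomHorizontalFlip(),
    T.ToTensor(),  # required: torch.stack expects tensors, not PIL images
    T.Normalize(mean=[0.4721, 0.4533, 0.4100],
                std=[0.2771, 0.2677, 0.2845]),
])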

Create test_loader.py with the following contents:

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import os.path as osp
import random

from PIL import Image
import numpy as np

import torch
from torch.utils.data import Dataset


def read_image(img_path):
    """Keep reading image until succeed.
    This can avoid IOError incurred by heavy IO process."""
    got_img = False
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while not got_img:
        try:
            img = Image.open(img_path).convert('RGB')
            got_img = True
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
            pass
    return img


class FewShotDataset_test(Dataset):
    """Few shot epoish Dataset

    Returns a task (Xtrain, Ytrain, Xtest, Ytest) to classify'
        Xtrain: [nKnovel*nExpemplars, c, h, w].
        Ytrain: [nKnovel*nExpemplars].
        Xtest:  [nTestNovel, c, h, w].
        Ytest:  [nTestNovel].
    """

    def __init__(self,
                 dataset,          # list of (image, category) pairs.
                 labels2inds,      # dict mapping category -> list of sample indices.
                 labelIds,         # sorted list of category ids.
                 nKnovel=5,        # number of novel categories per episode.
                 nExemplars=1,     # number of support examples per novel category.
                 nTestNovel=2*5,   # total number of query examples over all novel categories.
                 epoch_size=2000,  # number of episodes (tasks) per epoch.
                 transform=None,
                 load=True,
                 **kwargs
                 ):
        
        self.dataset = dataset
        self.labels2inds = labels2inds
        self.labelIds = labelIds
        self.nKnovel = nKnovel
        self.transform = transform

        self.nExemplars = nExemplars
        self.nTestNovel = nTestNovel
        self.epoch_size = epoch_size
        self.load = load

        # Fix the seed and pre-sample all test episodes once, so every run
        # evaluates on the same set of episodes.
        seed = 112
        random.seed(seed)
        np.random.seed(seed)

        self.Epoch_Exemplar = []
        self.Epoch_Tnovel = []
        for i in range(epoch_size):
            Tnovel, Exemplar = self._sample_episode()
            self.Epoch_Exemplar.append(Exemplar)
            self.Epoch_Tnovel.append(Tnovel)

    def __len__(self):
        return self.epoch_size

    def _sample_episode(self):
        """sampels a training epoish indexs.
        Returns:
            Tnovel: a list of length 'nTestNovel' with 2-element tuples. (sample_index, label)
            Exemplars: a list of length 'nKnovel * nExemplars' with 2-element tuples. (sample_index, label)
        """

        Knovel = random.sample(self.labelIds, self.nKnovel)
        nKnovel = len(Knovel)
        assert((self.nTestNovel % nKnovel) == 0)
        nEvalExamplesPerClass = int(self.nTestNovel / nKnovel)

        Tnovel = []
        Exemplars = []
        for Knovel_idx in range(len(Knovel)):
            ids = (nEvalExamplesPerClass + self.nExemplars)
            img_ids = random.sample(self.labels2inds[Knovel[Knovel_idx]], ids) 

            imgs_tnovel = img_ids[:nEvalExamplesPerClass]
            imgs_exemplars = img_ids[nEvalExamplesPerClass:]

            Tnovel += [(img_id, Knovel_idx) for img_id in imgs_tnovel]
            Exemplars += [(img_id, Knovel_idx) for img_id in imgs_exemplars]
        assert(len(Tnovel) == self.nTestNovel)
        assert(len(Exemplars) == nKnovel * self.nExemplars)
        random.shuffle(Exemplars)
        random.shuffle(Tnovel)

        return Tnovel, Exemplars

    def _creatExamplesTensorData(self, examples):
        """
        Creates the image/label tensor data for the given examples.

        Args:
            examples: a list of 2-element tuples. (sample_index, label).

        Returns:
            images: a tensor [nExemplars, c, h, w]
            labels: a tensor [nExemplars]
        """

        images = []
        labels = []
        for (img_idx, label) in examples:
            img = self.dataset[img_idx][0]
            if self.load:
                img = Image.fromarray(img)
            else:
                img = read_image(img)
            if self.transform is not None:
                img = self.transform(img)
            images.append(img)
            labels.append(label)
        images = torch.stack(images, dim=0)
        labels = torch.LongTensor(labels)
        return images, labels

    def __getitem__(self, index):
        Tnovel = self.Epoch_Tnovel[index]
        Exemplars = self.Epoch_Exemplar[index]
        Xt, Yt = self._creatExamplesTensorData(Exemplars)
        Xe, Ye = self._creatExamplesTensorData(Tnovel)
        return Xt, Yt, Xe, Ye

Create another `__init__.py` file (in the `dataset_loader` folder) that exposes the train/test loaders:

from __future__ import absolute_import

from .train_loader import FewShotDataset_train
from .test_loader import FewShotDataset_test


__loader_factory = {
        'train_loader': FewShotDataset_train,
        'test_loader': FewShotDataset_test,
}



def get_names():
    return list(__loader_factory.keys()) 


def init_loader(name, *args, **kwargs):
    if name not in list(__loader_factory.keys()):
        raise KeyError("Unknown loader: {}".format(name))
    return __loader_factory[name](*args, **kwargs)
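
Putting the pieces together: build the splits with `init_imgfewshot_dataset`, wrap a split in an episodic dataset via `init_loader`, and hand that to an ordinary `DataLoader`. A minimal sketch (assuming the `datasets` and `dataset_loader` packages created above and a tensor-producing `transform_train` like the one shown after train_loader.py; `load=True` because the pickle-backed dataset holds image arrays, not paths):

from torch.utils.data import DataLoader
from datasets import init_imgfewshot_dataset
from dataset_loader import init_loader

data = init_imgfewshot_dataset('miniImageNet')
trainset = init_loader('train_loader',
                       dataset=data.train,
                       labels2inds=data.train_labels2inds,
                       labelIds=data.train_labelIds,
                       nKnovel=5, nExemplars=1, nTestNovel=30,
                       epoch_size=2000, transform=transform_train, load=True)
trainloader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4)

for Xt, Yt, Xe, Ye, Yec in trainloader:
    # Xt: [4, 5, c, 84, 84] support images, Xe: [4, 30, c, 84, 84] query images
    print(Xt.size(), Yt.size(), Xe.size(), Ye.size())
    break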


The above covers reading data stored as pickle files.

Method 2:
1. Contents of miniImagenet.py (including the few-shot meta-dataset):

import os
import pickle
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class miniImagenet(Dataset):
    def __init__(self, args, partition='train', pretrain=False, is_sample=False, k=4096,
                 transform=None):
        super(miniImagenet, self).__init__()
        self.data_root = args.data_root
        self.partition = partition
        self.data_aug = args.data_aug
        self.mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0]
        self.std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0]
        self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
        self.pretrain = pretrain

        if transform is None:
            if self.partition == 'train' and self.data_aug:
                self.transform = transforms.Compose([
                    lambda x: Image.fromarray(x),
                    transforms.RandomCrop(84, padding=8),
                    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                    transforms.RandomHorizontalFlip(),
                    lambda x: np.asarray(x),
                    transforms.ToTensor(),
                    self.normalize
                ])
            else:
                self.transform = transforms.Compose([
                    lambda x: Image.fromarray(x),
                    # Note: RandomCrop is random even at test time; drop it if
                    # you need deterministic preprocessing.
                    transforms.RandomCrop(84, padding=8),
                    transforms.ToTensor(),
                    self.normalize
                ])
        else:
            self.transform = transform

        if self.pretrain:
            self.file_pattern = 'miniimagenet_category_split_train_phase_%s.pickle'           
        else:
            self.file_pattern = 'miniimagenet_category_split_%s.pickle'           
        self.data = {}
        with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f:
            data = pickle.load(f, encoding='latin1')
            self.imgs = data['data']
            self.labels = data['labels']

        # pre-process for contrastive sampling
        self.k = k
        self.is_sample = is_sample
        if self.is_sample:
            self.labels = np.asarray(self.labels)
            self.labels = self.labels - np.min(self.labels)
            num_classes = np.max(self.labels) + 1

            self.cls_positive = [[] for _ in range(num_classes)]
            for i in range(len(self.imgs)):
                self.cls_positive[self.labels[i]].append(i)

            self.cls_negative = [[] for _ in range(num_classes)]
            for i in range(num_classes):
                for j in range(num_classes):
                    if j == i:
                        continue
                    self.cls_negative[i].extend(self.cls_positive[j])

            self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
            self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]
            self.cls_positive = np.asarray(self.cls_positive)
            self.cls_negative = np.asarray(self.cls_negative)

    def __getitem__(self, item):

        img = np.asarray(self.imgs[item]).astype('uint8')
        img = self.transform(img)
        target = self.labels[item] - min(self.labels)

        if not self.is_sample:
            return img, target, item
        else:
            pos_idx = item
            replace = True if self.k > len(self.cls_negative[target]) else False
            neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
            sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
            return img, target, item, sample_idx
        
    def __len__(self):

        return len(self.labels)


class Metaminiimagenet(miniImagenet):  # episodic (few-shot) data handling
    
    def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True):
        super(Metaminiimagenet, self).__init__(args, partition, False)
        self.fix_seed = fix_seed
        self.n_ways = args.n_ways
        self.n_shots = args.n_shots
        self.n_queries = args.n_queries
        self.n_test_runs = args.n_test_runs
        self.n_aug_support_samples = args.n_aug_support_samples
        if train_transform is None:
            self.train_transform = transforms.Compose([
                    lambda x: Image.fromarray(x),
                    transforms.RandomCrop(84, padding=8),
                    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                    transforms.RandomHorizontalFlip(),
                    lambda x: np.asarray(x),
                    transforms.ToTensor(),
                    self.normalize
                ])
        else:
            self.train_transform = train_transform

        if test_transform is None:
            self.test_transform = transforms.Compose([
                    lambda x: Image.fromarray(x),
                    transforms.RandomCrop(84, padding=8),
                    transforms.ToTensor(),
                    self.normalize
                ])
        else:
            self.test_transform = test_transform

        # Group images by label: {label: [img, img, ...]}.
        self.data = {}
        self.imgs = np.array(self.imgs)
        for idx in range(self.imgs.shape[0]):
            if self.labels[idx] not in self.data:
                self.data[self.labels[idx]] = []
            self.data[self.labels[idx]].append(self.imgs[idx])
        self.classes = list(self.data.keys())

    def __getitem__(self, item):
        if self.fix_seed:
            np.random.seed(item)
        cls_sampled = np.random.choice(self.classes, self.n_ways, False)
        support_xs = []
        support_ys = []
        query_xs = []
        query_ys = []
        for idx, cls in enumerate(cls_sampled):
            imgs = np.asarray(self.data[cls]).astype('uint8')
            support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False)
            support_xs.append(imgs[support_xs_ids_sampled])
            support_ys.append([idx] * self.n_shots)
            query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled)
            query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, True)
            query_xs.append(imgs[query_xs_ids])
            query_ys.append([idx] * query_xs_ids.shape[0])
        support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array(
            query_xs), np.array(query_ys)
        num_ways, n_queries_per_way, height, width, channel = query_xs.shape
        query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel))
        query_ys = query_ys.reshape((num_ways * n_queries_per_way, ))
                
        support_xs = support_xs.reshape((-1, height, width, channel))
        if self.n_aug_support_samples > 1:
            support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1))
            support_ys = np.tile(support_ys.reshape((-1, )), (self.n_aug_support_samples))
        support_xs = np.split(support_xs, support_xs.shape[0], axis=0)
        query_xs = query_xs.reshape((-1, height, width, channel))
        query_xs = np.split(query_xs, query_xs.shape[0], axis=0)
        
        support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs)))
        query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs)))

        return support_xs, support_ys, query_xs, query_ys      
        
    def __len__(self):
        return self.n_test_runs
    
    
if __name__ == '__main__':
    args = lambda x: None  # lightweight stand-in for an argparse.Namespace
    args.n_ways = 5
    args.n_shots = 1
    args.n_queries = 12
    args.data_root = r'D:/rfs-master/rfs-master/data/miniimagenet/'
    args.data_aug = True
    args.n_test_runs = 5
    args.n_aug_support_samples = 1
    dataset = miniImagenet(args, 'val')
    print(len(dataset))
    print(dataset.__getitem__(500)[0].shape)

    metadataset = Metaminiimagenet(args)
    print(len(metadataset))
    print(metadataset.__getitem__(500)[0].size())
    print(metadataset.__getitem__(500)[1].shape)
    print(metadataset.__getitem__(500)[2].size())
    print(metadataset.__getitem__(500)[3].shape)
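
As in Method 1, the meta-dataset plugs directly into a `DataLoader`, with each batch element being a complete episode. A minimal sketch reusing the `args` object from the `__main__` block above:

from torch.utils.data import DataLoader

meta_valloader = DataLoader(Metaminiimagenet(args, 'val'),
                            batch_size=1, shuffle=False, num_workers=2)
for support_xs, support_ys, query_xs, query_ys in meta_valloader:
    # support_xs: [1, n_ways*n_shots, c, 84, 84]; query_xs: [1, n_ways*n_queries, c, 84, 84]
    print(support_xs.size(), query_xs.size())
    break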

**Code to generate the pickle files:** requires the images folder plus train.csv, val.csv, and test.csv.

import os
import csv
import pickle
from PIL import Image
import numpy as np
# paths to the .csv files
train_list = r"G:\fewshot test\train.csv"
val_list = r"G:\fewshot test\val.csv"
test_list = r"E:\few-shot\pinggu\GAN-Metrics\code\data\test.csv"
# path to the image files
miniImagenet_Root = r"E:\few-shot\pinggu\GAN-Metrics\code\data\images"

def gen_pickle(csv_file, outfile):
    labelid = []  # class names, in order of first appearance
    labels = []   # numeric label for each image
    data = []     # image arrays
    with open(csv_file) as f:
        f_csv = csv.reader(f)
        next(f_csv)  # skip the CSV header row
        for row in f_csv:
            # row[0] is the image filename, row[1] is the class name.
            filename = os.path.join(miniImagenet_Root, row[0])
            if row[1] not in labelid:
                labelid.append(row[1])
                print(filename, labelid.index(row[1]), row[1])
            j = labelid.index(row[1])

            # convert('RGB') guards against grayscale/RGBA images.
            img = Image.open(filename).convert('RGB')
            img = img.resize([84, 84])
            data.append(np.array(img))
            labels.append(j)
            img.close()

    dataset = {"data": data, "labels": labels}
    with open(outfile, 'wb') as f:
        pickle.dump(dataset, f)

if __name__ == '__main__':
    gen_pickle(test_list, 'trafficsign_category_split_test.pickle')
    gen_pickle(val_list,  'trafficsign_category_split_val.pickle')
    gen_pickle(train_list,  'trafficsign_category_split_train.pickle')
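
After generating the files, it is worth loading one back to confirm the structure matches what the loaders above expect (a quick sanity check, assuming the pickle was just written to the working directory):

import pickle
import numpy as np

with open('trafficsign_category_split_test.pickle', 'rb') as f:
    d = pickle.load(f)
print(np.array(d['data']).shape)  # (num_images, 84, 84, 3)
print(sorted(set(d['labels'])))   # contiguous class ids starting at 0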

(2) Data stored in train/val/test directories; contents of the calling file main.py:

import os
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import util  # project-specific helpers: ColorJitter, Lighting, subset_of_ImageNet_train_split

# data loading
parser = argparse.ArgumentParser(description='Trains ResNet-50 on ImageNet', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--data_dir', type=str, default='',
                    help='directory the data is read from')
parser.add_argument('--save_dir', type=str, default='',
                    help='folder where results are to be stored')
parser.add_argument('--mini_imagenet', type=bool, default=False,
                    help='use a subset of ImageNet for training')
parser.add_argument('--subset', type=int, default=260,
                    help='number of samples from each class; with 1300 samples per class, 260/1300 is 20%% of the training set')
parser.add_argument('--batch_size', type=int, default=256,
                    help='mini-batch size (used by the DataLoaders below)')
parser.add_argument('--workers', type=int, default=4,
                    help='number of data-loading workers')
args = parser.parse_args()

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=mean, std=std)
# util.ColorJitter and util.Lighting are tensor-space augmentations from the
# project's util module (Lighting is AlexNet-style PCA color augmentation).
jittering = util.ColorJitter(brightness=0.4, contrast=0.4,
                             saturation=0.4)
lighting = util.Lighting(alphastd=0.1,
                         eigval=[0.2175, 0.0188, 0.0045],
                         eigvec=[[-0.5675, 0.7192, 0.4009],
                                 [-0.5808, -0.0045, -0.8140],
                                 [-0.5836, -0.6948, 0.4203]])

transform_train =  transforms.Compose([
		transforms.RandomResizedCrop(224),
		transforms.RandomHorizontalFlip(),
		transforms.ToTensor(),
		jittering,
		lighting,
		normalize,
	])

transform_test = transforms.Compose([
	transforms.Resize(256),
	transforms.CenterCrop(224),
	transforms.ToTensor(),
	transforms.Normalize((0.485, 0.456, 0.406),
						 (0.229, 0.224, 0.225)),
])

train_data = datasets.ImageFolder(root=os.path.join(args.data_dir, 'train'), transform=transform_train)
if args.mini_imagenet:
	# use 20% of the training set, for researchers with limited compute
	train_data = util.subset_of_ImageNet_train_split(train_data, subset=args.subset)

test_data = datasets.ImageFolder(root=os.path.join(args.data_dir, 'val'), transform=transform_test)

trainloader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
testloader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)  # no need to shuffle at evaluation
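
A quick way to confirm the loaders are wired up correctly is to pull a single batch and check its shape (a sanity-check sketch; sizes follow --batch_size):

images, targets = next(iter(trainloader))
print(images.size())   # [batch_size, 3, 224, 224]
print(targets.size())  # [batch_size]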
