将 CIFAR-10 数据集中某一类(此处为类别 0)的样本随机删减到原来的 1/10,构造不平衡数据集:训练集总量由 50000 张变为 45500 张(删去 4500 张);测试集同理,由 10000 张变为 9100 张(删去 900 张)。

主程序(训练集和测试集分开处理,下面的代码处理的是测试集):

import random
import pickle
import numpy as np
from pickled import *
import os
def load_data_cifar(filename, mode='cifar10'):
    """Read one pickled CIFAR batch file and return its images, labels, names.

    The file is unpickled with ``encoding='bytes'``, so its dictionary keys are
    looked up as bytes:
      cifar10 keys:  b'batch_label', b'labels', b'data', b'filenames'
      cifar100 keys: b'filenames', b'batch_label', b'fine_labels',
                     b'coarse_labels', b'data'

    Args:
        filename (str): path of the batch file to read.
        mode (str): 'cifar10' (labels from b'labels') or 'cifar100'
            (labels from b'fine_labels').

    Returns:
        tuple: ``(data, labels, img_names)``. For an unrecognised ``mode`` a
        message is printed and ``(None, None, None)`` is returned. The file
        is still opened and unpickled first.
    """
    # Which entry of the batch dict holds the labels for each supported mode.
    label_keys = (('cifar10', b'labels'), ('cifar100', b'fine_labels'))

    with open(filename, 'rb') as fh:
        batch = pickle.load(fh, encoding='bytes')

    label_key = next((key for name, key in label_keys if mode == name), None)
    if label_key is None:
        print("mode should be in ['cifar10', 'cifar100']")
        return None, None, None

    return batch[b'data'], batch[label_key], batch[b'filenames']
def load_cifar10(cifar10_path, mode='train'):
    """Load the CIFAR-10 training or test split from its python batch files.

    Args:
        cifar10_path (str): directory containing ``data_batch_1`` ..
            ``data_batch_5`` and ``test_batch``.
        mode (str): 'train' to concatenate the five training batches, or
            'test' to read the single test batch.

    Returns:
        tuple: ``(data, labels, img_names)``. For 'train', ``data`` is the
        row-wise stack of every batch's data array, and ``labels`` and
        ``img_names`` are the concatenated per-batch lists. For any other
        ``mode`` a message is printed and ``(None, None, None)`` is returned,
        matching ``load_data_cifar``. This way callers that unpack three
        values do not crash on a bare ``None``.
    """
    if mode == "train":
        data_chunks = []
        labels_all = []
        img_names_all = []
        for i in range(1, 6):
            # Normalise any backslash separators to forward slashes.
            filename = os.path.join(cifar10_path,
                                    'data_batch_' + str(i)).replace('\\', '/')
            print("Loading {}".format(filename))
            data, labels, img_names = load_data_cifar(filename, mode='cifar10')
            data_chunks.append(data)
            labels_all += labels
            img_names_all += img_names
        # Stack once at the end instead of re-copying the growing array on
        # every iteration.
        return np.vstack(data_chunks), labels_all, img_names_all
    elif mode == "test":
        filename = os.path.join(cifar10_path, 'test_batch').replace('\\', '/')
        print("Loading {}".format(filename))
        return load_data_cifar(filename, mode='cifar10')

    print("mode should be in ['train', 'test']")
    return None, None, None


if __name__ == "__main__":
    # Change this to wherever the extracted CIFAR-10 python batches live.
    cifar10_path = "../data/cifar-10-batches-py"

    # The train and test splits are processed separately; this run builds the
    # imbalanced test split. To process the training split instead, load it
    # with:
    #     load_cifar10(cifar10_path, mode='train')
    data_cifar10_test, labels_cifar10_test, img_names_cifar10_test = \
        load_cifar10(cifar10_path, mode='test')

    # Positions of every sample labelled as class 0.
    label0_test = [pos for pos, lab in enumerate(labels_cifar10_test)
                   if lab == 0]
    print("label0_test count:", len(label0_test))

    # Keep 1/10 of class 0 and randomly choose the rest for removal. The
    # count is derived from the data (it was previously a hard-coded 900),
    # so the same code works for whichever split is loaded.
    num_to_remove = len(label0_test) - len(label0_test) // 10
    removed_idx = sorted(random.sample(label0_test, num_to_remove))

    # Drop the chosen rows from the image array ...
    data_cifar10_test_del = np.delete(data_cifar10_test, removed_idx, axis=0)
    # ... and the matching label / file-name entries. A set-based filter keeps
    # the three containers aligned in one pass each. The old approach called
    # list.pop() repeatedly, and each call shifts the tail of the list.
    removed_set = set(removed_idx)
    labels_cifar10_test = [lab for pos, lab in enumerate(labels_cifar10_test)
                           if pos not in removed_set]
    img_names_cifar10_test = [name for pos, name
                              in enumerate(img_names_cifar10_test)
                              if pos not in removed_set]

    print(len(labels_cifar10_test),
          len(img_names_cifar10_test), len(data_cifar10_test_del))
    save_path = 'data/cifar-10-1-batches-py/test'
    # pickled() asserts that save_path is an existing directory, so create it.
    os.makedirs(save_path, exist_ok=True)
    pickled(save_path, data_cifar10_test_del,
            labels_cifar10_test, img_names_cifar10_test)

下面是 pickled.py 文件的内容(主程序通过 `from pickled import *` 导入其中的 `pickled` 函数,用于把处理后的数据重新保存为 pickle 文件):

import os
import pickle
BIN_COUNTS = 1  # default number of output files


def pickled(savepath, data, label, fnames, bin_num=BIN_COUNTS, mode="test"):
    """Split a dataset across ``bin_num`` pickle files inside ``savepath``.

    Each file is named ``data_batch_<i>`` with ``i`` counted from 0. It holds
    a dict with the keys 'data', 'labels', 'filenames' and 'batch_label'.
    Samples are divided into contiguous, near-equal runs, and every sample
    lands in exactly one file.

    Args:
        savepath (str): existing directory to write the files into.
        data (array): image data, an n x 3072 array (sliced by rows).
        label (list): image labels, a list of length n.
        fnames (list of str): image names, a list of length n.
        bin_num (int): number of files to split the samples across.
        mode (str): 'train' or 'test'; only affects the 'batch_label' text.

    Raises:
        AssertionError: if ``savepath`` is not a directory or there are no
            samples to write.
        ZeroDivisionError: if ``bin_num`` is 0.

    NOTE(review): the dict keys are Python 3 ``str``. When these files are
    reloaded with ``pickle.load(..., encoding='bytes')`` the keys presumably
    stay ``str``, so a reader that indexes ``b'data'`` would not find them.
    The file names also differ from ``data_batch_1..5`` / ``test_batch``.
    Confirm which reader consumes these files.
    """
    assert os.path.isdir(savepath)
    total_num = len(fnames)
    samples_per_bin = total_num / bin_num  # reported only; bounds use ints
    print("samples_per_bin", samples_per_bin)
    assert samples_per_bin > 0
    for i in range(bin_num):
        # Integer floor arithmetic: the last bin always ends exactly at
        # total_num. The old int(i * float) form could truncate below it
        # and silently drop trailing samples.
        start = i * total_num // bin_num
        end = (i + 1) * total_num // bin_num
        batch = {'data': data[start:end, :],
                 'labels': label[start:end],
                 'filenames': fnames[start:end]}
        if mode == "train":
            batch['batch_label'] = "training batch {} of {}".format(i, bin_num)
        else:
            batch['batch_label'] = "testing batch {} of {}".format(i, bin_num)

        with open(os.path.join(savepath, 'data_batch_' + str(i)), 'wb') as fi:
            pickle.dump(batch, fi)

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值