mini-imagenet数据处理过程_从头开始训练

mini-imagenet
miniImageNet包含100类共60000张彩色图片,其中每类有600个样本,每张图片的大小被resize到了84×84。这里,这个数据集的训练集和测试集的类别划分为:5:1。相比于CIFAR10数据集,miniImageNet数据集更加复杂,但更适合进行原型设计和实验研究。
https://studio.brainpp.com/dataset/2313?name=mini_imagenet
下载过程
注册账号-启动环境-在dataset文件夹找到mini-imagenet数据集压缩包-将其拖动到workspace文件夹-右击压缩包-点击download即可
数据处理过程

# -*- coding: utf-8 -*-
"""
这部分程序可以将放在一起的图片按照训练集测试集和验证集分开
但是有个问题
因为这个数据集是从imagenet抽取出来的
训练集测试集、测试集和验证集的数据类别都不相同
其类别总数是100类
这意味着我们不能从头开始训练
因此还需要对这个数据集进行重新划分
"""

import csv
import os
from PIL import Image
train_csv_path="./mini_imagenet/train.csv"
val_csv_path="./mini_imagenet/val.csv"
test_csv_path="./mini_imagenet/test.csv"
 
train_label={}
val_label={}
test_label={}
with open(train_csv_path) as csvfile:
    csv_reader=csv.reader(csvfile)
    birth_header=next(csv_reader)
    for row in csv_reader:
        train_label[row[0]]=row[1]
 
with open(val_csv_path) as csvfile:
    csv_reader=csv.reader(csvfile)
    birth_header=next(csv_reader)
    for row in csv_reader:
        val_label[row[0]]=row[1]
 
with open(test_csv_path) as csvfile:
    csv_reader=csv.reader(csvfile)
    birth_header=next(csv_reader)
    for row in csv_reader:
        test_label[row[0]]=row[1]
 
img_path="./mini_imagenet/images"
new_img_path="./mini_imagenet/images_OK"
for png in os.listdir(img_path):
    path = img_path+ '/' + png
    im=Image.open(path)
    if(png in train_label.keys()):
        tmp=train_label[png]
        temp_path=new_img_path+'/train'+'/'+tmp
        if(os.path.exists(temp_path)==False):
            os.makedirs(temp_path)
        t=temp_path+'/'+png
        im.save(t)
        # with open(temp_path, 'wb') as f:
        #     f.write(path)
 
    elif(png in val_label.keys()):
        tmp = val_label[png]
        temp_path = new_img_path + '/val' + '/' + tmp
        if (os.path.exists(temp_path) == False):
            os.makedirs(temp_path)
        t = temp_path + '/' + png
        im.save(t)
 
    elif(png in test_label.keys()):
        tmp = test_label[png]
        temp_path = new_img_path + '/test' + '/' + tmp
        if (os.path.exists(temp_path) == False):
            os.makedirs(temp_path)
        t = temp_path + '/' + png
        im.save(t)

运行完这段程序可以得到如下数据分布

train
	类别1
		image1
		image2
	类别2
		image1
		image2
	……
test
	类别3
		image1
		image2
	类别4
		image1
		image2
	……
test
	类别5
		image1
		image2
	类别6
		image1
		image2
	……

将所有的类别都放到一个data文件夹里
得到如下数据分布

data
	类别1
		image1
		image2
		image3
		……
	类别2
		image1
		image2
		image3
		……
	……

# -*- coding: utf-8 -*-
"""
这段代码会从data数据的每个类别中抽取一定比例(可以自己定,这里是1/6)的数据重新组成一个测试集(注意不是拷贝是剪切)
"""

 
import os
import random
import shutil
 
# source_file:源路径, target_ir:目标路径
def cover_files(source_dir, target_ir):
    for file in os.listdir(source_dir):
        source_file = os.path.join(source_dir, file)
 
        if os.path.isfile(source_file):
            shutil.copy(source_file, target_ir)
 
def ensure_dir_exists(dir_name):
    """Makes sure the folder exists on disk.
  Args:
    dir_name: Path string to the folder we want to create.
  """
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
 
def moveFile(file_dir, save_dir):
    ensure_dir_exists(save_dir)
    path_dir = os.listdir(file_dir)    #取图片的原始路径
    filenumber=len(path_dir)
    rate=0.1667    #自定义抽取图片的比例,比方说100张抽10张,那就是0.1
    picknumber=int(filenumber*rate) #按照rate比例从文件夹中取一定数量图片
    print(picknumber)
    sample = random.sample(path_dir, picknumber)  #随机选取picknumber数量的样本图片
    # print (sample)
    for name in sample:
        shutil.move(file_dir+name, save_dir+name)

def mkdir(path):
    folder = os.path.exists(path)
    if not folder:                   #判断是否存在文件夹如果不存在则创建为文件夹
        os.makedirs(path)            #makedirs 创建文件时如果路径不存在会创建这个路径
        print("---  new folder...  ---")
        print("---  OK  ---")
    else:
        print("---  There is this folder!  ---")
 
if __name__ == '__main__':
    
    path='./mini_imagenet/image_CDD/'
    dirs = os.listdir( path+'data/' )
    for file in dirs:
        file_dir = path+'data/'+file+'/'  #源图片文件夹路径
        print(file_dir)
        save_dir = path+'test/'+file  #移动到新的文件夹路径
        print(save_dir)
        mkdir(save_dir)  #创造文件夹
        save_dir = save_dir+'/'
        moveFile(file_dir, save_dir)

抽取完成的的数据分布如下,可见将一个完整的数据集按照一定比例拆分为了训练集(data)和测试集
记得将data改为train

data
	类别1
		image3
		……
	类别2
		image2
		image3
		……
	……
test
	类别1
		image3
		……
	类别2
		image1
		……
	……

最后使用pytorch导入模型

# -*- coding: utf-8 -*-
"""
保证数据分布如上即可
"""

import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# from wideresnet import WideResNet

BATCH_SIZE = 4
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 归一化处理
 # 需要更多数据预处理,自己查
])
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 归一化处理
 # 需要更多数据预处理,自己查
])

#读取数据
dataset_train = datasets.ImageFolder('./mini_imagenet/image_CDD/train', transform_train)
dataset_test = datasets.ImageFolder('./mini_imagenet/image_CDD/test', transform_test)
#dataset_val = datasets.ImageFolder('data/val', transform)

# 上面这一段是加载测试集的
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True) # 训练集
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True) # 测试集
#val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True) # 验证集
# 对应文件夹的label
print(dataset_train.class_to_idx)   # 这是一个字典,可以查看每个标签对应的文件夹,也就是你的类别。
                                    # 训练好模型后输入一张图片测试,比如输出是99,就可以用字典查询找到你的类别名称
print(dataset_test.class_to_idx)
#print(dataset_val.class_to_idx)



评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值