ImageNet2012 分类数据集准备工作

"本文档详细介绍了如何准备ImageNet2012分类数据集,包括训练集和验证集的解压及处理。同时,提供了一个使用PyTorch的数据加载模块"data_loader.py",用于高效加载数据集。在训练阶段,调用"data_loader"模块加载数据。"
摘要由CSDN通过智能技术生成

ImageNet2012 分类数据集准备工作

自行下载数据集
ILSVRC2012_img_train.tar
ILSVRC2012_img_test.tar
1.解压训练集
mkdir train
tar xvf ILSVRC2012_img_train.tar -C ./train
touch unzip.sh

#!/bin/bash
dir=./train 
for x in `ls $dir/*tar`
do	
  filename=`basename $x .tar`     
  mkdir $dir/$filename     
  tar -xvf $x -C $dir/$filename 
done 

chmod +x ./unzip.sh
./unzip.sh

2.解压并处理验证集
mkdir val
tar xvf ILSVRC2012_img_val.tar -C ./val
下载处理脚本valprep.sh:
wget https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh
cd val
chmod +x ./valprep.sh
./valprep.sh
rm valprep.sh

数据集的目录树

-imagenet
	-train
	-val

封装数据加载模块

# data_loader.py
import os
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets

def data_loader(root, batch_size=256, workers=1, pin_memory=True):
    traindir = os.path.join(root, 'train')
    valdir = os.path.join(root, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
    )
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize
        ])
    )
    
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=pin_memory,
        sampler=None
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=pin_memory
    )
    return train_loader, val_loader

训练调用

# Data loading
from data_loader import data_loader	

def main():
	...
	# args.data = './imagenet'
    train_loader, val_loader = data_loader(args.data, args.batch_size, args.workers, args.pin_memory)
    ...
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值