使用pytorch进行数据cifar10数据分类

这篇博客记录了使用PyTorch进行CIFAR10数据分类的学习过程,包括数据预处理、网络训练、模型应用等步骤。通过解压数据、划分文件夹、编写训练和预测代码,最终在CIFAR10上实现了98%的准确率。虽然遇到多卡训练的问题,但整个流程使作者对PyTorch框架有了更深入的理解。
摘要由CSDN通过智能技术生成

记录学习pytorch的过程,从分类任务做起,就从最常见的cifar10下手,数据可在kaggle下载,具体步骤和代码请参考本文余下内容。在cifar10上能有98%的准确率

1、文件件代码组织目录如下所示:

.
├── data
│   ├── class2idx.json
│   ├── test
│   ├── train
│   │   ├── airplane
│   │   ├── automobile
│   │   ├── bird
│   │   ├── cat
│   │   ├── deer
│   │   ├── dog
│   │   ├── frog
│   │   ├── horse
│   │   ├── ship
│   │   └── truck
│   └── trainLabels.csv
├── log
├── models
│   ├── best_model
│   └── last_model
├── sampleSubmission.csv
└── scripts
    ├── densenet.py
    ├── __init__.py
    ├── predict.py
    ├── preprocess.py
    ├── result.csv
    └── train_model.py

16 directories, 13 files

2、解压并依据trainLabels.csv,将图片划分到10个文件夹下,方便后面步骤调用ImageFolder进行读取数据,preprocess.py如下所示。

import os
import shutil
from tqdm import tqdm

train_data = '../data/train'
train_labels = '../data/trainLabels.csv'

def split_data(labels):
    with open(labels) as f_labels:
        for line in tqdm(f_labels.readlines()):
            line = line.strip()
            if 'id' in line:
                continue
            im_name,im_cls = line.split(',')
            dst_path = os.path.join(train_data,im_cls)
            if not os.path.exists(dst_path):
                os.makedirs(dst_path)
            shutil.move(os.path.join(train_data,im_name+'.png'),dst_path)

 

3、数据处理完成之后,写训练网络的代码train_model.py,如下所示。未完成多卡训练的功能,使用多卡训练一个epoch后悔报错,因此暂且放弃。代码不够优雅,等有空再去整理。函数参数中只需要一个dataloader就可以,代码中的两个是想做mixup加上去,代码未完成。

import torch
import torchvision
from torchvision import transforms,datasets,models
import os
import  numpy as np
import time
import json
from torch import nn,optim
from torch.autograd import Variable
from tqdm import tqdm


def train_model(model, criterion, optimizer, data_loaders1,data_loaders2, num_images,cuda_device=False,finetune=None, num_epochs=25,CUDA_ID=0):
    '''
    :param model: model for training
    :param criterion: loss function
    :param optimizer:
   
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值