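This post records my process of learning PyTorch. I start with a classification task on the most common dataset, CIFAR-10, which can be downloaded from Kaggle. The concrete steps and code are in the rest of this post. The model can reach 98% accuracy on CIFAR-10.
1. The files and code are organized in the following directory layout: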
.
├── data
│   ├── class2idx.json
│   ├── test
│   ├── train
│   │   ├── airplane
│   │   ├── automobile
│   │   ├── bird
│   │   ├── cat
│   │   ├── deer
│   │   ├── dog
│   │   ├── frog
│   │   ├── horse
│   │   ├── ship
│   │   └── truck
│   └── trainLabels.csv
├── log
├── models
│   ├── best_model
│   └── last_model
├── sampleSubmission.csv
└── scripts
    ├── densenet.py
    ├── __init__.py
    ├── predict.py
    ├── preprocess.py
    ├── result.csv
    └── train_model.py

16 directories, 13 files
2. Unzip the data and, following trainLabels.csv, move the images into 10 class folders so that later steps can read them with ImageFolder. preprocess.py is shown below.
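import os
import shutil
from tqdm import tqdm

train_data = '../data/train'
train_labels = '../data/trainLabels.csv'

def split_data(labels):
    # Move each training image into a subfolder named after its class.
    with open(labels) as f_labels:
        for line in tqdm(f_labels.readlines()):
            line = line.strip()
            # Skip the CSV header row.
            if 'id' in line:
                continue
            im_name, im_cls = line.split(',')
            dst_path = os.path.join(train_data, im_cls)
            if not os.path.exists(dst_path):
                os.makedirs(dst_path)
            shutil.move(os.path.join(train_data, im_name + '.png'), dst_path)

# The original listing defines split_data but never calls it; this entry point is added so the script runs.
if __name__ == '__main__':
    split_data(train_labels)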
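Once the images sit in per-class folders, the label mapping can be previewed without loading any images. The sketch below rests on two assumptions: that class indices follow the sorted order of the subfolder names (which is how ImageFolder assigns them by default), and that a file such as data/class2idx.json stores this kind of mapping. The function names build_class2idx and save_class2idx and the JSON layout are illustrative and are not taken from the author's code.

```python
import json
import os


def build_class2idx(train_dir):
    """Map each class subfolder name to an index, in sorted name order."""
    classes = sorted(
        entry.name for entry in os.scandir(train_dir) if entry.is_dir()
    )
    return {cls_name: idx for idx, cls_name in enumerate(classes)}


def save_class2idx(mapping, json_path):
    """Write the mapping to a JSON file (layout is an assumption)."""
    with open(json_path, 'w') as f:
        json.dump(mapping, f, indent=2)
```

For example, calling build_class2idx('../data/train') after preprocess.py has run should return a dictionary with one entry per class folder, which save_class2idx can then write to disk.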
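3. Once the data is prepared, write the training code, train_model.py, shown below. Multi-GPU training is not implemented: with multiple GPUs, an error was raised after the first epoch, so I have set it aside for now. The code is not very elegant, and I will tidy it up when I have time. The function only needs one dataloader; the two in the signature were meant for adding mixup, which is unfinished.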
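For reference, the mixup idea mentioned above blends two batches and their labels with a coefficient drawn from a Beta distribution, which is why two dataloaders are useful: each supplies one of the two batches. The following is only a NumPy sketch of that arithmetic under stated assumptions (one-hot labels, batches of equal shape, a hypothetical function name mixup_batch). It is not the author's unfinished implementation, and the same formulas would apply to PyTorch tensors inside a training loop.

```python
import numpy as np


def mixup_batch(x1, y1, x2, y2, alpha=1.0, rng=None):
    """Blend two batches with one coefficient lam ~ Beta(alpha, alpha).

    x1, x2: float arrays of identical shape (batch, ...).
    y1, y2: one-hot label arrays of identical shape (batch, num_classes).
    Returns the mixed inputs, the mixed labels, and lam.
    """
    rng = np.random.default_rng() if rng is None else rng
    lam = float(rng.beta(alpha, alpha))
    mixed_x = lam * x1 + (1.0 - lam) * x2
    mixed_y = lam * y1 + (1.0 - lam) * y2
    return mixed_x, mixed_y, lam
```

The mixed labels are soft targets, so a loss that accepts class probabilities is needed; an equivalent option is to weight the two ordinary losses by lam and 1 - lam. The author's train_model.py listing follows.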
import torch
import torchvision
from torchvision import transforms,datasets,models
import os
import numpy as np
import time
import json
from torch import nn,optim
from torch.autograd import Variable
from tqdm import tqdm
def train_model(model, criterion, optimizer, data_loaders1, data_loaders2, num_images, cuda_device=False, finetune=None, num_epochs=25, CUDA_ID=0):
    '''
    :param model: model for training
    :param criterion: loss function
    :param optimizer: