导入包:
import argparse
import os
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import numpy as np
from models import API_Net
from datasets import RandomDataset, BatchDataset, BalancedBatchSampler
from utils import accuracy, AverageMeter, save_checkpoint
**1)相关参数:**如果显存不够用的话,就需要调整n_classes,n_samples。
经常用到的参数 | |
---|---|
num_works | 获取批量样本时,相当于线程,有多个途径来提供这一个batch的样本,需要根据cpu核数以及RAM 来设置 |
batch_size | 一批样本的个数 |
epochs | 循环次数 |
start_epoch | 因意外中断,重新启动训练时的epoch |
learning_rate | 学习率 |
momentum | |
weight_decay | |
resume | 是否恢复存在的模型参数 |
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--exp_name', default=None, type=str,
help='name of experiment')
parser.add_argument('--data', metavar='DIR', default='',
help='path to dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', #需要用到
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=150, type=int, metavar='N', #总epoch
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',#意外停止训练后,重启训练开始的epoch
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=6, type=int,
#此处的batch_size是指对验证集进行的设置,训练集的每批样本数通过设置batchsampler就设置了。
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=1, type=int, #打印频率,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--evaluate-freq', default=10, type=int,
help='the evaluation frequence')
parser.add_argument('--resume', default='./model_best.pth.tar', type=str, metavar='PATH', #恢复模型
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action=