API_Net官方代码之训练网络

本文介绍了API_Net的训练过程,包括导入相关包,设置参数,定义模型和损失函数,进行训练并保存模型,同时提供了测试部分的简要说明。
摘要由CSDN通过智能技术生成

导入包:

import argparse
import os
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import numpy as np
from models import API_Net
from datasets import RandomDataset, BatchDataset, BalancedBatchSampler
from utils import accuracy, AverageMeter, save_checkpoint

**1)相关参数:**如果显存不够用的话,就需要调整n_classes,n_samples。

经常用到的参数
num_works 获取批量样本时,相当于线程,有多个途径来提供这一个batch的样本,需要根据cpu核数以及RAM 来设置
batch_size 一批样本的个数
epochs 循环次数
start_epoch 因意外中断,重新启动训练时的epoch
learning_rate 学习率
momentum
weight_decay
resume 是否恢复存在的模型参数
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--exp_name', default=None, type=str,
                    help='name of experiment')
parser.add_argument('--data', metavar='DIR', default='',
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', #需要用到
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=150, type=int, metavar='N', #总epoch
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',#意外停止训练后,重启训练开始的epoch
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=6, type=int,  
#此处的batch_size是指对验证集进行的设置,训练集的每批样本数通过设置batchsampler就设置了。
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=1, type=int, #打印频率,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--evaluate-freq', default=10, type=int,
                    help='the evaluation frequence')
parser.add_argument('--resume', default='./model_best.pth.tar', type=str, metavar='PATH', #恢复模型
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action=
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值