Faster R-CNN PyTorch Project Tutorial
faster_rcnn_pytorch: Faster RCNN with PyTorch. Project URL: https://gitcode.com/gh_mirrors/fa/faster_rcnn_pytorch
1. Project Directory Structure
```
faster_rcnn_pytorch/
├── data/
│   ├── cache/
│   ├── pretrain_model/
│   ├── results/
│   └── VOCdevkit2007/
├── faster_rcnn/
│   ├── datasets/
│   ├── faster_rcnn/
│   ├── networks/
│   ├── roi_data_layer/
│   ├── utils/
│   └── vis/
├── experiments/
│   ├── cfgs/
│   ├── logs/
│   └── models/
├── lib/
│   ├── datasets/
│   ├── fast_rcnn/
│   ├── networks/
│   ├── roi_data_layer/
│   ├── utils/
│   └── vis/
├── scripts/
│   ├── train_faster_rcnn.py
│   └── test_faster_rcnn.py
├── README.md
└── requirements.txt
```
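Before training, it can help to confirm that a checkout matches the layout above. The following is a minimal sketch: the helper `missing_dirs` and the `EXPECTED_DIRS` tuple are hypothetical and simply mirror the directories in the tree shown here, so adjust the list if your checkout differs.

```python
from pathlib import Path

# Hypothetical list of directories, copied from the directory tree above.
EXPECTED_DIRS = (
    "data/cache",
    "data/pretrain_model",
    "data/results",
    "data/VOCdevkit2007",
    "experiments/cfgs",
    "experiments/logs",
    "experiments/models",
    "scripts",
)


def missing_dirs(root, expected=EXPECTED_DIRS):
    """Return the expected sub-directories (relative paths) that are absent under root."""
    root = Path(root)
    # Keep the order of `expected` so the report is easy to compare with the tree.
    return [rel for rel in expected if not (root / rel).is_dir()]
```

Called from the repository root, `missing_dirs(".")` returns an empty list when every listed directory exists, and otherwise names exactly which ones still need to be created or downloaded.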
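Directory overview
- data/: stores datasets, caches, pretrained models, and results.
- faster_rcnn/: the core code directory, containing dataset handling, model definitions, network architectures, the data layer, utilities, and visualization.
- experiments/: experiment configurations, logs, and saved models.
- lib/: a library containing dataset handling, Fast R-CNN, network architectures, the data layer, utilities, and visualization.
- scripts/: training and testing scripts.
- README.md: project documentation.
- requirements.txt: project dependencies.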
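The training script imports `cfg_from_file` and `cfg_from_list` from `lib.fast_rcnn.config`, which suggests that configurations (for example from `experiments/cfgs/`) can be overridden by a list of strings. Their implementation is not shown in this tutorial; the sketch below assumes the py-faster-rcnn convention, in which the list alternates dotted keys and values, each value is parsed with `ast.literal_eval`, and a value whose type differs from the default is rejected. The function `apply_cfg_overrides` and the config keys used with it are hypothetical.

```python
from ast import literal_eval


def apply_cfg_overrides(cfg, cfg_list):
    """Apply KEY VALUE pairs such as ["TRAIN.SCALES", "(600,)"] to a nested dict in place.

    Hypothetical sketch of a py-faster-rcnn-style cfg_from_list, not this project's code.
    """
    if len(cfg_list) % 2 != 0:
        raise ValueError("cfg_list must alternate keys and values")
    for dotted_key, raw_value in zip(cfg_list[0::2], cfg_list[1::2]):
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # raises KeyError for an unknown section
        if leaf not in node:
            raise KeyError(dotted_key)  # only existing keys may be overridden
        try:
            # "0.01" -> 0.01, "(800,)" -> (800,), "True" -> True
            value = literal_eval(raw_value)
        except (ValueError, SyntaxError):
            value = raw_value  # bare words such as "vgg16" stay plain strings
        if type(value) is not type(node[leaf]):
            raise TypeError(
                f"{dotted_key}: got {type(value).__name__}, "
                f"expected {type(node[leaf]).__name__}"
            )
        node[leaf] = value
```

The strict type check is the main design point: passing `"1"` for a boolean option such as a hypothetical `TRAIN.USE_FLIPPED` raises `TypeError` instead of silently storing an integer.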
2. Project Startup Files
Training script
```python
# scripts/train_faster_rcnn.py
import _init_paths  # local helper; conventionally puts the project's library dirs on sys.path
from datasets.factory import get_imdb
from faster_rcnn.train import train_net
from lib.fast_rcnn.config import cfg_from_file, cfg_from_list, get_output_dir
from lib.networks.factory import get_network
import argparse
import pprint
import numpy as np
import sys


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--dataset', dest='dataset',
                        help='training dataset',
                        default='voc_2007_trainval', type=str)
    parser.add_argument('--net', dest='net',
                        help='vgg16, res50, res101, res152',
                        default='vgg16', type=str)
    parser.add_argument('--start_epoch', dest='start_epoch',
                        help='starting epoch',
                        default=1, type=int)
    parser.add_argument('--epochs', dest='max_epochs',
                        help='number of epochs to train',
                        default=20, type=int)
    parser.add_argument('--disp_interval', dest='disp_interval',
                        help='number of iterations between progress displays',
                        default=100, type=int)
    parser.add_argument('--save_interval', dest='save_interval',
                        help='number of iterations between checkpoints',
                        default=10000, type=int)
    parser.add_argument('--save_dir', dest='save_dir',
                        help='directory to save models', default="models",
                        type=str)
    parser.add_argument('--nw', dest='num_workers',
                        help='number of workers to load data',
                        default=0, type=int)
    parser.add_argument('--cuda', dest='cuda',
                        help='whether to use CUDA',
                        action='store_true')
    parser.add_argument('--ls', dest='large_scale',
                        help='whether to use large image scale',
                        action='store_true')
    parser.add_argument('--mGPUs', dest='mGPUs',
                        help='whether to use multiple GPUs',
                        action='store_true')
    parser.add_argument('--bs', dest='batch_size',
                        help='batch size',
                        default=1, type=int)
    parser.add_argument('--cag', dest='class_agnostic',
                        help='whether to perform class-agnostic bbox regression',
                        action='store_true')
    # ... the original listing is truncated here; any further options are not shown
    args = parser.parse_args()
    return args
```
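Several flags store their value under a different attribute name (`--epochs` becomes `args.max_epochs`, `--bs` becomes `args.batch_size`, `--nw` becomes `args.num_workers`), and the `store_true` flags default to `False`. The sketch below rebuilds a subset of the parser with the same flags, `dest` names, and defaults, then parses an explicit argument list so the mapping can be checked without starting training. `build_parser` is a hypothetical name, not part of the project.

```python
import argparse


def build_parser():
    # Subset of the options in scripts/train_faster_rcnn.py, with the same
    # flags, dest names and defaults (help strings omitted for brevity).
    parser = argparse.ArgumentParser(description="Train a Fast R-CNN network")
    parser.add_argument("--dataset", dest="dataset", default="voc_2007_trainval", type=str)
    parser.add_argument("--net", dest="net", default="vgg16", type=str)
    parser.add_argument("--epochs", dest="max_epochs", default=20, type=int)
    parser.add_argument("--bs", dest="batch_size", default=1, type=int)
    parser.add_argument("--nw", dest="num_workers", default=0, type=int)
    parser.add_argument("--cuda", dest="cuda", action="store_true")
    parser.add_argument("--mGPUs", dest="mGPUs", action="store_true")
    return parser


# Parse an explicit list instead of sys.argv so the result is reproducible.
args = build_parser().parse_args(["--net", "res101", "--epochs", "5", "--bs", "4", "--cuda"])
print(args.net, args.max_epochs, args.batch_size, args.cuda, args.mGPUs)
# res101 5 4 True False
```

On the command line, the same arguments would correspond to something like `python scripts/train_faster_rcnn.py --net res101 --epochs 5 --bs 4 --cuda`; any code that reads the parsed options must use the `dest` names, not the flag names.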