Pytorch | yolov3代码详解一
说明:仅供自己学习记录,有参考其他博主,侵删
代码来源:eriklindernoren/PyTorch-YOLOv3
参考链接
参考链接
augmentations.py
import torch
import torch.nn.functional as F
import numpy as np
##########################################################################
#数据增广,主要是翻转
##########################################################################
def horisontal_flip(images, targets):#对图像和标签进行镜像翻转
images = torch.flip(images, [-1])#镜像翻转
targets[:, 2] = 1 - targets[:, 2]
#targets是对应的标签[置信度,中心点高度,中心点宽度,框高度,框宽度]
#镜像翻转时,受影响的只有targets[:, 2], 中心点宽度
return images, targets
logger.py
import tensorflow as tf
##########################################################################
#训练记录
##########################################################################
class Logger(object):
def __init__(self, log_dir):#log_dir是日志的路径
"""Create a summary writer logging to log_dir."""
self.writer = tf.summary.FileWriter(log_dir)#创建一个summary writer
#由于版本问题,tf.summary.FileWriter可能会报错,改为tf.compat.v1.summary.FileWriter
def scalar_summary(self, tag, value, step): #记录a scalar variable
"""Log a scalar variable."""
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
#由于版本问题,tf.summary.FileWriter可能会报错,改为tf.compat.v1.summary.FileWriter
def list_of_scalars_summary(self, tag_value_pairs, step):#记录scalar variables
"""Log scalar variables."""
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value) for tag, value in tag_value_pairs])
self.writer.add_summary(summary, step)
#由于版本问题,tf.summary.FileWriter可能会报错,改为tf.compat.v1.summary.FileWriter
#可以查看博客 https://blog.csdn.net/encodets/article/details/54172807
parse_config.py
##########################################################################
#解析输入的参数
##########################################################################
#用于解析网络结构文件:如:config/yolov3.cfg
def parse_model_config(path):
"""Parses the yolo-v3 layer configuration file and returns module definitions"""
file = open(path, 'r')
lines = file.read().split('\n') #按行读取 .cfg 中的内容
lines = [x for x in lines if x and not x.startswith('#')] #去掉 # 开始的行
lines = [x.rstrip().lstrip() for x in lines] # get rid of fringe whitespaces 去掉每行数据前后的空格
module_defs = []
"""
从前往后遍历每一行内容,遇到方括号[],就新建一个字典,并将整个【】结构的内容存放在该字典中
所有的字典都存放在一个列表中,存放结果(列表)如下:
[
{'type': 'net', 'batch': '16', 'subdivisions': '1', 'width': '416',…………},
{'type': 'convolutional', 'batch_normalize': '1', 'filters': '128', 'size': '3',…………},
{'type': 'shortcut', 'from': '-3', 'activation': 'linear'},
{'type': 'route', 'layers': '-1, 36'},
………… # 省略号
]
"""
for line in lines:
if line.startswith('['): # This marks the start of a new block
module_defs.append({})
module_defs[-1]['type'] = line[1:-1].rstrip()
if module_defs[-1]['type'] == 'convolutional':
module_defs[-1]['batch_normalize'] = 0
"""
【yolo】前面的【convolutional】的 activation=linear,且没有 batch_normalize=1 这行。
在解析时,batch_normalize 默认设置为 0. 如果不是yolo前面的,而是其他的,则后面会更新为1
"""
else:
key, value = line.split("=")
value = value.strip()
module_defs[-1][key.rstrip()] = value.strip()
return module_defs
#用于解析数据集配置文件:如:config/coco.data
def parse_data_config(path):
"""Parses the data configuration file"""
"""
添加 coco.data 中没有的两个值:“gpus”:“0, 1, 2, 3” 和 “num_workers”:“10”
将coco.data 每行中 = 前面的值作为 key, 后面的值作为 value ,存放在 字典中
最终输出端结果(字典)是
{'gpus': '0,1,2,3',
'num_workers': '10',
'classes': '80',
'train': 'data/coco/trainvalno5k.txt',
'valid': 'data/coco/5k.txt',
'names': 'data/coco.names',
'backup': 'backup/',
'eval': 'coco'}
"""
options = dict()
options['gpus'] = '0,1,2,3'
options['num_workers'] = '10'
with open(path, 'r') as fp:
lines = fp.readlines()
for line in lines:
line = line.strip()
if line == '' or line.startswith('#'):
continue
key, value = line.split('=')
options[key.strip()] = value.strip()
return options