YOLOv5 Code Walkthrough (Beginner Series, Part 1)

A script-by-script walkthrough of the yolov5 codebase.

1. train.py

Since I'm still a beginner, I'll step through the code line by line in a debugger, record the overall execution flow first, and fill in the principles behind each trick bit by bit later.

def main(opt):
	set_logging(RANK)
	if RANK in [-1,0]:
		print(colorstr('train: ')+', '.join(f'{k}={v}' for k,v in vars(opt).items()))
		check_git_status()
		check_requirements(exclude=['thop'])

RANK defaults to -1 here.
Then logging is configured by set_logging() in general.py:

def set_logging(rank=-1,verbose=True):
	logging.basicConfig(
		format="%(message)s",
		level=logging.INFO if (verbose and rank in [-1,0]) else logging.WARN)

The other two calls check that the required dependencies are installed and that the local code is in sync with the GitHub repo; I don't have those enabled.

def main(opt):
	# ... (earlier lines unchanged, see above) ...

	#Resume
	wandb_run=check_wandb_resume(opt) #None
	#so the resume if-branch is not taken
	opt.data,opt.cfg,opt.hyp=check_file(opt.data),check_file(opt.cfg),check_file(opt.hyp)
	

The check_file() function (general.py) validates the given file path. When the path doesn't point to an existing file, execution usually falls into the else: branch, which searches for the file recursively:

def check_file(file):
	file=str(file)
	if Path(file).is_file() or file =='': #exists
		return file
	elif file.startswith(('http:/','https:/')): #download
		url=str(Path(file)).replace(':/','://')
		file = Path(urllib.parse.unquote(file)).name.split('?')[0] #'%2F' ==> '/' 
		print(f'Downloading {url} to {file}...')
		torch.hub.download_url_to_file(url,file)
		assert Path(file).exists() and Path(file).stat().st_size>0,f'File download failed: {url}'
		return file
	else:
		files=glob.glob('./**/'+file,recursive=True) # find file
		assert len(files),f'File not found:{file}'
		assert len(files)==1, f"Multiple files match '{file}', specify exact path: {files}"
		return files[0]
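A few hypothetical calls to illustrate the three branches (paths are placeholders):

check_file('data/coco128.yaml')      # existing path: returned as-is
check_file('coco128.yaml')           # not found directly: searched for under ./**/
check_file('https://host/x.yaml')    # URL: downloaded first, local filename returned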
def main(opt):
	# ... (earlier lines unchanged, see above) ...

	assert len(opt.cfg) or len(opt.weights),'either --cfg or --weights must be specified'
	opt.img_size.extend([opt.img_size[-1]]*(2-len(opt.img_size)))
	opt.name='evolve' if opt.evolve else opt.name
	opt.save_dir=str(increment_path(Path(opt.project)/opt.name, exist_ok=opt.exist_ok | opt.evolve))

The line opt.img_size.extend([opt.img_size[-1]]*(2-len(opt.img_size))) pads the size list to two entries, one for train and one for test; if only a single size was passed, it is duplicated, so the output here is [640, 640].
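A minimal standalone sketch of that padding idiom:

img_size = [640]                                     # only one size given
img_size.extend([img_size[-1]] * (2 - len(img_size)))
print(img_size)                                      # [640, 640] -> train size, test size

img_size = [640, 512]                                # both given: 2-len(...) is 0, nothing appended
img_size.extend([img_size[-1]] * (2 - len(img_size)))
print(img_size)                                      # [640, 512]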
The line opt.save_dir=str(increment_path(Path(opt.project)/opt.name, exist_ok=opt.exist_ok | opt.evolve)) defines the save directory. Breaking down the arguments:
opt.project: runs/train
opt.name: exp
Path objects can be joined directly with the / operator.
increment_path() is an auto-incrementing helper: if the target directory already exists, it appends a running number (exp2, exp3, ...), so each run gets its own directory (general.py):

def increment_path(path,exist_ok=False,sep='',mkdir=False):
	path=Path(path)
	if path.exists() and not exist_ok:
		suffix=path.suffix
		path=path.with_suffix('')
		dirs=glob.glob(f"{path}{sep}*") #similar paths
		matches=[re.search(rf"%s{sep}(\d+)"%path.stem,d) for d in dirs]
		i=[int(m.groups()[0]) for m in matches if m] #indices
		n=max(i)+1 if i else 2 # increment number
		path=Path(f"{path}{sep}{n}{suffix}") #update path
	dir=path if path.suffix=='' else path.parent #directory
	if not dir.exists() and mkdir:
		dir.mkdir(parents=True,exist_ok=True)
	return path

Regular expressions: \d+ matches one or more digits. path.stem is the name of the last component of the path, without its suffix, e.g.:

test_path = Path('/Users/xxx/Desktop/project/data/')
print(test_path.stem)
# output: 'data'
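Since increment_path also leans on suffix, with_suffix and parent, a quick pathlib refresher:

from pathlib import Path

p = Path('runs/train/exp.txt')
print(p.stem)             # 'exp' -> final component without its suffix
print(p.suffix)           # '.txt'
print(p.with_suffix(''))  # runs/train/exp
print(p.parent)           # runs/train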

def main(opt):
	# ... (earlier lines unchanged, see above) ...

	device = select_device(opt.device,batch_size=opt.batch_size)

This calls select_device() (torch_utils.py), which picks and returns the device:

def select_device(device='',batch_size=None):
	s=f'YOLOv5 {git_describe() or date_modified()} torch {torch.__version__}'
	cpu=device.lower()=='cpu' 
	if cpu:
		os.environ['CUDA_VISIBLE_DEVICES']='-1'
	elif device:
		os.environ['CUDA_VISIBLE_DEVICES']=device
		assert torch.cuda.is_available(),f'CUDA unavailable,invalid device {device} requested'
	
	cuda=not cpu and torch.cuda.is_available()
	if cuda:
		devices=device.split(',') if device else '0'
		n=len(devices)
		if n>1 and batch_size:
			assert batch_size%n==0,f'batch-size {batch_size} not multiple of GPU count {n}'
		space=' '* (len(s)+1)
		for i,d in enumerate(devices):
			p=torch.cuda.get_device_properties(i) #e.g. 'GeForce GTX 1080 Ti'
			s+=f"{'' if i==0 else space}CUDA:{d} ({p.name}, {p.total_memory/1024**2}MB)\n"
	else:
		s+='CPU\n'
	logger.info(s.encode().decode('ascii','ignore') if platform.system()=='Windows' else s)
	return torch.device('cuda:0' if cuda else 'cpu')
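Some hypothetical ways to call it:

select_device('')                    # default: cuda:0 if a GPU is available, else cpu
select_device('cpu')                 # force CPU
select_device('0,1', batch_size=16)  # two GPUs: batch_size must be divisible by 2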
def main(opt):
	# ... (earlier lines unchanged, see above) ...

	device = select_device(opt.device,batch_size=opt.batch_size)
	#skipping the DDP branch for now
	if not opt.evolve:
		train(opt.hyp,opt,device)

Here opt.hyp is data/hyp.scratch.yaml.
To recap, the code so far has done the following:
1. set up logging
2. checked that the paths of the data and model yaml files are valid
3. configured the model save path and the train/test input sizes
4. determined how many GPUs the current environment has
5. started training

-------divider---------------------------------------------------------------------------------------------------------
Now into the train() function.

def train(hyp,opt,device):
	#unpack options
	save_dir,epochs,batch_size,weights,single_cls,evolve,data,cfg,resume,notest,nosave,workers=\
		opt.save_dir,opt.epochs,opt.batch_size,opt.weights,opt.single_cls,opt.evolve,opt.data,opt.cfg,opt.resume,opt.notest,opt.nosave,opt.workers
	
	#Directories
	save_dir=Path(save_dir)
	wdir=save_dir/'weights'
	wdir.mkdir(parents=True,exist_ok=True)
	last=wdir/'last.pt'
	best=wdir/'best.pt'
	results_file=save_dir/'results.txt'
	

Up to here, the code unpacks the options and defines the save paths for the last checkpoint, the best checkpoint, and the results txt.

def train(hyp,opt,device):
	# ... (earlier lines unchanged, see above) ...
	
	#Hyperparameters
	if isinstance(hyp,str):
		with open(hyp) as f:
			hyp=yaml.safe_load(f)
	logger.info(colorstr('hyperparameters: ')+', '.join(f'{k}={v}' for k,v in hyp.items()))
	
	#save run settings
	with open(save_dir/'hyp.yaml','w') as f:
		yaml.safe_dump(hyp,f,sort_keys=False)
	with open(save_dir/ 'opt.yaml','w') as f:
		yaml.safe_dump(vars(opt),f,sort_keys=False)

Here yaml.safe_load(f) is the standard interface for parsing a yaml file, and yaml.safe_dump() serializes a dict back out to yaml.
vars(opt) converts the argparse Namespace into a plain dict.
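A minimal round-trip sketch of these three calls (toy options, not the real opt):

import yaml
from argparse import Namespace

opt = Namespace(epochs=300, batch_size=16)
d = vars(opt)                              # Namespace -> dict: {'epochs': 300, 'batch_size': 16}
text = yaml.safe_dump(d, sort_keys=False)  # dict -> yaml text
print(yaml.safe_load(text) == d)           # True: parses back to the same dict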

def train(hyp,opt,device):
	# ... (earlier lines unchanged, see above) ...
	
	#Configure
	plots=not evolve #create plots
	cuda=device.type!='cpu'
	init_seeds(2+RANK) #RANK=-1
	

init_seeds() here initializes the random seeds so that the same training configuration is reproducible (general.py):

def init_seeds(seed=0):
	random.seed(seed)
	np.random.seed(seed)
	init_torch_seeds(seed)

torch_utils.py

def init_torch_seeds(seed=0):
	torch.manual_seed(seed)
	if seed==0:
		cudnn.benchmark,cudnn.deterministic=False,True
	else:
		cudnn.benchmark,cudnn.deterministic=True,False

cudnn.deterministic=True forces deterministic convolution algorithms, avoiding run-to-run randomness.
cudnn.benchmark=True lets cuDNN benchmark and pick the fastest algorithms; faster, but non-deterministic.
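A tiny demo of why seeding makes runs reproducible:

import torch

torch.manual_seed(0)
a = torch.rand(3)
torch.manual_seed(0)
b = torch.rand(3)
print(torch.equal(a, b))  # True: the same seed yields the same 'random' numbers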

def train(hyp,opt,device):
	# ... (earlier lines unchanged, see above) ...
	
	#load the data yaml
	with open(data) as f:
		data_dict=yaml.safe_load(f)
	
	#Loggers
	loggers={'wandb':None,'tb':None} #loggers dict
	if RANK in [-1,0]:
		#TensorBoard
		if not evolve:
			prefix=colorstr('tensorboard: ')
			logger.info(f"{prefix}Start with 'tensorboard --logdir' {opt.project}', view at http://localhost:6006/")
			loggers['tb']=SummaryWriter(str(save_dir))
	
	# W&B
	opt.hyp=hyp #add hyperparameters
	run_id=torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
	run_id=run_id if opt.resume else None 
	wandb_logger=WandbLogger(opt,save_dir.stem,run_id,data_dict)
	loggers['wandb']=wandb_logger.wandb
	if loggers['wandb']:
		data_dict=wandb_logger.data_dict
		weights,epochs,hyp=opt.weights,opt.epochs,opt.hyp

SummaryWriter(str(save_dir)) sets where the TensorBoard logs are written. W&B is left disabled by default here, since I don't use it.
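A minimal SummaryWriter sketch (the directory is a placeholder):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/train/exp')  # event files are written under this directory
for step in range(3):
    writer.add_scalar('train/loss', 1.0 / (step + 1), step)  # tag, value, global step
writer.close()
# then: tensorboard --logdir runs/train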

def train(hyp,opt,device):
	# ... (earlier lines unchanged, see above) ...

	nc=1 if single_cls else int(data_dict['nc']) # number of classes
	names=['item'] if single_cls and len(data_dict['names'])!=1 else data_dict['names']
	assert len(names)==nc, '%g names found for nc=%g dataset in %s'%(len(names),nc,data) #check
	is_coco=data.endswith('coco.yaml') and nc==80

This reads the number of classes and the class names, and checks whether the current dataset is COCO.
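A toy data_dict of the shape being parsed here (all values are placeholders):

data_dict = {'train': 'images/train', 'val': 'images/val',
             'nc': 2, 'names': ['cat', 'dog']}
single_cls = False
nc = 1 if single_cls else int(data_dict['nc'])  # 2
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']
assert len(names) == nc                         # passes: 2 names for nc=2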

def train(hyp,opt,device):
	# ... (earlier lines unchanged, see above) ...
	
	# Model
	pretrained=weights.endswith('.pt')
	if pretrained:
		with torch_distributed_zero_first(RANK):
			weights=attempt_download(weights)
		ckpt=torch.load(weights,map_location=device)
		model=Model(cfg or ckpt['model'].yaml,ch=3,nc=nc,anchors=hyp.get('anchors')).to(device) #create the model
		exclude=['anchors'] if (cfg or hyp.get('anchors')) and not resume else [] #anchors
		state_dict=ckpt['model'].float().state_dict()
		state_dict=intersect_dicts(state_dict,model.state_dict(),exclude=exclude)
		model.load_state_dict(state_dict,strict=False)
		logger.info('Transferred %g/%g items from %s' %(len(state_dict),len(model.state_dict()),weights)) 
	else:
		model=Model(cfg,ch=3,nc=nc,anchors=hyp.get('anchors')).to(device)

The intersect_dicts() function intersects the two state dicts: it keeps only the keys that exist in both, have matching tensor shapes, and are not excluded (here 'anchors'), so only compatible pretrained weights are transferred.
torch_utils.py

def intersect_dicts(da,db,exclude=()):
	return {k:v for k,v in da.items() if k in db and not any(x in k for x in exclude) and v.shape==db[k].shape}
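Continuing with the definition above, a toy intersection (tensors are placeholders):

import torch

da = {'conv.weight': torch.zeros(3, 3), 'anchors': torch.zeros(2), 'extra': torch.zeros(1)}
db = {'conv.weight': torch.ones(3, 3), 'anchors': torch.ones(3)}
print(list(intersect_dicts(da, db, exclude=['anchors'])))
# ['conv.weight']: 'extra' is missing from db, 'anchors' is excluded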
def train(hyp,opt,device):
	# ... (earlier lines unchanged, see above) ...
	
	# assemble the train and test data paths
	train_path=data_dict['train']
	test_path=data_dict['val']
	
	#optionally freeze some layers
	freeze=[]
	for k,v in model.named_parameters():
		v.requires_grad=True
		if any(x in k for x in freeze):
			print('freezing %s'%k)
			v.requires_grad=False

To recap, train.py so far has done the following:

  1. parsed the various yaml parameters
  2. initialized the random seeds
  3. loaded the data yaml
  4. loaded the pretrained model
  5. assembled the train and test data paths

-------divider---------------------------------------------------------------------------------------------------------
def train(hyp,opt,device):
	# ... (earlier lines unchanged, see above) ...
	
	#configure optimizer parameters
	nbs=64 #nominal batch size
	accumulate=max(round(nbs/batch_size),1) #e.g. accumulate=32 when batch_size=2
	hyp['weight_decay']*=batch_size*accumulate/nbs #scale from the base 0.0005
	logger.info(f"Scaled weight_decay={hyp['weight_decay']}")
	
	pg0,pg1,pg2=[],[],[] # optimizer parameter groups
	for k,v in model.named_modules():
		if hasattr(v,'bias') and isinstance(v.bias,nn.Parameter):
			pg2.append(v.bias) #biases
			
		if isinstance(v,nn.BatchNorm2d):
			pg0.append(v.weight) #no decay
		elif hasattr(v,'weight') and isinstance(v.weight,nn.Parameter):
			pg1.append(v.weight) #apply decay
	
	if opt.adam:
		optimizer=optim.Adam(pg0,lr=hyp['lr0'],betas=(hyp['momentum'],0.999)) #adjust beta1 to momentum
	else:
		optimizer=optim.SGD(pg0,lr=hyp['lr0'],momentum=hyp['momentum'],nesterov=True)
	#register decay and biases: each add_param_group call appends another dict to optimizer.param_groups
	optimizer.add_param_group({'params':pg1,'weight_decay':hyp['weight_decay']})
	optimizer.add_param_group({'params':pg2})
	logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other'%(len(pg2),len(pg1),len(pg0)))
	del pg0,pg1,pg2
	
	#configure the learning-rate schedule
	if opt.linear_lr:
		lf=lambda x:(1-x/(epochs-1))*(1.0-hyp['lrf'])+hyp['lrf'] 
	else:#OneCycleLR
		lf=one_cycle(1,hyp['lrf'],epochs) #cosine 1->hyp['lrf']
	scheduler=lr_scheduler.LambdaLR(optimizer,lr_lambda=lf)
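A quick check of the accumulate / weight-decay arithmetic, assuming batch_size=16:

nbs, batch_size = 64, 16
accumulate = max(round(nbs / batch_size), 1)           # 4: step the optimizer every 4 batches
weight_decay = 0.0005 * batch_size * accumulate / nbs  # 0.0005: effective batch stays ~= nbs
print(accumulate, weight_decay)                        # 4 0.0005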

The one_cycle() code lives in general.py:

def one_cycle(y1=0.0,y2=1.0,steps=100):
	return lambda x:((1-math.cos(x*math.pi/steps))/2)*(y2-y1)+y1
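Evaluating the returned lambda on a toy 3-epoch run shows the cosine shape (one_cycle as defined above; it needs math imported):

import math

lf = one_cycle(1, 0.2, steps=3)             # decay from 1 down to lrf=0.2
print([round(lf(x), 2) for x in range(4)])  # [1.0, 0.8, 0.4, 0.2]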
def train(hyp,opt,device):
	# ... (earlier lines unchanged, see above) ...
	
	#EMA
	ema=ModelEMA(model) if RANK in [-1,0] else None

EMA stands for Exponential Moving Average; the goal is to keep a smoothed, moving-average copy of the model weights.
The code is in torch_utils.py:

class ModelEMA:
	def __init__(self,model,decay=0.999,updates=0):
		#create the EMA copy of the model
		self.ema=deepcopy(model.module if is_parallel(model) else model).eval() #FP32 ema
		self.updates=updates
		self.decay=lambda x:decay*(1-math.exp(-x/2000))
		for p in self.ema.parameters():
			p.requires_grad_(False)
	
	def update(self,model):
		#update the EMA weights
		with torch.no_grad():
			self.updates+=1
			d=self.decay(self.updates)
			msd=model.module.state_dict() if is_parallel(model) else model.state_dict() 
			for k,v in self.ema.state_dict().items():
				if v.dtype.is_floating_point:
					v*=d
					v+=(1.-d)*msd[k].detach()
	
	def update_attr(self,model,include=(),exclude=('process_group','reducer')):
		#update attributes
		copy_attr(self.ema,model,include,exclude)
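The decay lambda ramps up from 0 toward 0.999 as updates accumulate; a quick evaluation:

import math

decay = lambda x: 0.999 * (1 - math.exp(-x / 2000))  # same ramp as in __init__
print(round(decay(1), 4))      # 0.0005: early on, the EMA tracks the raw weights almost 1:1
print(round(decay(2000), 4))   # 0.6315
print(round(decay(20000), 4))  # 0.999: later, the EMA moves very slowly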

This post is getting long and the blog editor is struggling, so the rest will follow in a later installment.
