3. 解读 evaluate 部分
3.1 evaluate.py
首先从main()函数开始。
def main():
    """Entry point: load an evaluation config from a params file and run evaluate().

    Usage: ``python eval.py params.py``. The params file must define a
    module-level ``TRAINING_PARAMS`` dict. Logging is configured to print the
    timestamp, the emitting file name and the message.

    Exits with status 1 (after logging an error) when the command line is
    malformed, the params file does not exist, or it cannot be loaded as a
    Python module.
    """
    import importlib.util

    logging.basicConfig(level=logging.DEBUG,
                        format="[%(asctime)s %(filename)s] %(message)s")
    if len(sys.argv) != 2:
        logging.error("Usage: python eval.py params.py")
        sys.exit(1)
    params_path = sys.argv[1]
    if not os.path.isfile(params_path):
        logging.error("no params file found! path: {}".format(params_path))
        sys.exit(1)

    # Load exactly the file validated above. The previous
    # import_module(params_path[:-3]) only worked for a bare "name.py"
    # resolvable on sys.path and failed for paths like "./params.py".
    module_name = os.path.splitext(os.path.basename(params_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, params_path)
    if spec is None or spec.loader is None:
        logging.error("cannot load params file as a module! path: {}".format(params_path))
        sys.exit(1)
    params_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(params_module)
    config = params_module.TRAINING_PARAMS

    # Scale batch_size by the number of GPUs listed in "parallels".
    config["batch_size"] *= len(config["parallels"])
    # Start evaluation, restricting CUDA to the configured devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, config["parallels"]))
    evaluate(config)
此处插入 evaluate/params.py 文件的内容（其中定义了配置字典 TRAINING_PARAMS）。
# Evaluation configuration read by main()/evaluate(). main() multiplies
# "batch_size" by len("parallels") and exports "parallels" as
# CUDA_VISIBLE_DEVICES.
TRAINING_PARAMS = {
    "model_params": {
        "backbone_name": "darknet_53",
        "backbone_pretrained": "",
    },
    "yolo": {
        # One group of three anchor pairs per detection scale; evaluate()
        # builds one YOLOLoss per group.
        "anchors": [
            [[116, 90], [156, 198], [373, 326]],
            [[30, 61], [62, 45], [59, 119]],
            [[10, 13], [16, 30], [33, 23]],
        ],
        # NOTE(review): the class count must match the checkpoint in
        # "pretrain_snapshot". The COCO val/annotation paths and the
        # "official" weights below are presumably 80-class; 20 only fits a
        # model trained on a 20-class dataset (e.g. VOC) -- confirm.
        "classes": 20,
    },
    "batch_size": 4,
    "iou_thres": 0.5,
    "val_path": "../data/coco/5k.txt",
    "annotation_path": "../data/coco/annotations/instances_val2014.json",
    "img_h": 416,
    "img_w": 416,
    "parallels": [0],
    # Checkpoint to evaluate; can be replaced with a model saved after training.
    "pretrain_snapshot": "../weights/official_yolov3_weights_pytorch.pth",
}
def evaluate(config):
    """Set up the network, per-scale YOLO losses and validation loader for evaluation.

    Args:
        config: the TRAINING_PARAMS dict. Reads "pretrain_snapshot",
            "yolo" ("anchors", "classes"), "img_w", "img_h", "val_path",
            "batch_size" and the optional "num_workers" (number of
            DataLoader worker processes, default 16).
    """
    is_training = False
    # Load and initialize network; train(False) switches the module to
    # inference mode (e.g. BatchNorm uses its running statistics).
    net = ModelMain(config, is_training=is_training)
    net.train(is_training)
    # Set data parallel. The checkpoint is loaded into the wrapped model, so
    # its keys are matched against the DataParallel ("module."-prefixed)
    # names -- assumes the snapshot was saved from a DataParallel model.
    net = nn.DataParallel(net)
    net = net.cuda()
    # Restore pretrain model. torch.load unpickles the file: only load
    # trusted checkpoints.
    if config["pretrain_snapshot"]:
        state_dict = torch.load(config["pretrain_snapshot"])
        net.load_state_dict(state_dict)
    else:
        logging.warning("missing pretrain_snapshot!!!")
    # YOLO loss with 3 scales: one YOLOLoss per anchor group.
    img_size = (config["img_w"], config["img_h"])
    yolo_losses = [
        YOLOLoss(config["yolo"]["anchors"][i], config["yolo"]["classes"], img_size)
        for i in range(3)
    ]
    # DataLoader over the validation list; order is kept (shuffle=False).
    dataloader = torch.utils.data.DataLoader(
        COCODataset(config["val_path"], img_size, is_training=False),
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=config.get("num_workers", 16),
        pin_memory=False,
    )
代码运行时yolo_losses