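This part covers the core of the code: the network structure and the loss function. The loss function deserves particular attention. YOLO's loss is not hard to understand in itself, but the code relies on many tensor operations and related helper functions, which makes it look somewhat complicated.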
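The other parts of this series:
YOLO Code Walkthrough (1): Code Overview and Usage
YOLO Code Walkthrough (2): Data Processing
YOLO Code Walkthrough (3): Model and Loss Function
YOLO Code Walkthrough (4): Training and Testing Code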
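The schematic of the network structure from the YOLO paper is shown below:
Network structure code: yolo_tiny_net.py
The network here differs slightly from the one in the YOLO paper (this tiny variant uses nine convolutional layers followed by three fully connected layers), but the overall design is the same.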
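To see where the 7*7 grid and the output length come from, the spatial sizes in the inference method listed next can be traced with plain arithmetic. The following is a minimal sketch, not code from the repository: the helper `trace_spatial_size` is hypothetical, the 448x448 input size is assumed (it is the size used in the YOLO paper), and the convolutions are assumed to preserve spatial size (stride 1 with 'SAME' padding). The six pooling layers and the values 20 and 2 are taken from the listing and its comments.

```python
def trace_spatial_size(image_size, num_pools, pool_stride=2):
    """Return the feature-map side length before and after each pooling layer.

    Assumption: every convolution keeps the spatial size unchanged
    (stride 1, 'SAME' padding), so only the pooling layers shrink it.
    """
    sizes = [image_size]
    for _ in range(num_pools):
        if sizes[-1] % pool_stride != 0:
            raise ValueError("side length is not divisible by the pool stride")
        sizes.append(sizes[-1] // pool_stride)
    return sizes


IMAGE_SIZE = 448     # assumed input size, as in the YOLO paper
NUM_POOLS = 6        # number of max_pool calls in inference()
NUM_CLASSES = 20     # from the 7*7*(20+5*2) comment in the listing
BOXES_PER_CELL = 2   # from the same comment

sizes = trace_spatial_size(IMAGE_SIZE, NUM_POOLS)
cell_size = sizes[-1]

# Length of the flattened feature map fed into the first local layer.
fc_input_length = cell_size * cell_size * 1024
# Length of the final prediction vector per image.
output_length = cell_size * cell_size * (NUM_CLASSES + 5 * BOXES_PER_CELL)

print(sizes)            # [448, 224, 112, 56, 28, 14, 7]
print(fc_input_length)  # 50176
print(output_length)    # 1470
```

Under these assumptions each of the six 2x2 poolings halves the side length, so 448 / 2^6 = 7, which matches a cell_size of 7.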
def inference(self, images):
    """Build the yolo_tiny network.

    Args:
        images: 4-D tensor [batch_size, image_height, image_width, channels]
    Returns:
        predicts: 4-D tensor [batch_size, cell_size, cell_size, num_classes + 5 * boxes_per_cell]
    """
    conv_num = 1

    # Six conv + max-pool stages; each pooling halves the spatial size
    temp_conv = self.conv2d('conv' + str(conv_num), images, [3, 3, 3, 16], stride=1)
    conv_num += 1
    temp_pool = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_pool, [3, 3, 16, 32], stride=1)
    conv_num += 1
    temp_pool = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_pool, [3, 3, 32, 64], stride=1)
    conv_num += 1
    temp_conv = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 64, 128], stride=1)
    conv_num += 1
    temp_conv = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 128, 256], stride=1)
    conv_num += 1
    temp_conv = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 256, 512], stride=1)
    conv_num += 1
    temp_conv = self.max_pool(temp_conv, [2, 2], 2)

    # Three more conv layers without pooling
    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 512, 1024], stride=1)
    conv_num += 1
    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 1024, 1024], stride=1)
    conv_num += 1
    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 1024, 1024], stride=1)
    conv_num += 1

    temp_conv = tf.transpose(temp_conv, (0, 3, 1, 2))  # (N, H, W, C) => (N, C, H, W)

    # Fully connected layers
    local1 = self.local('local1', temp_conv, self.cell_size * self.cell_size * 1024, 256)
    local2 = self.local('local2', local1, 256, 4096)
    local3 = self.local('local3', local2, 4096,
                        self.cell_size * self.cell_size * (self.num_classes + self.boxes_per_cell * 5),
                        leaky=False, pretrain=False, train=True)

    # Reshape the tensor produced by the fully connected layers.
    # Their output is a 2-D tensor [N, cell_size * cell_size * (num_classes + boxes_per_cell * 5)],
    # where N is the number of images in the batch.
    # In the YOLO paper this length is 7*7*(20+5*2).