Tensorflow YOLO代码解析(3)

下面介绍最核心的部分:网络结构和损失函数。尤其是损失函数部分,YOLO的损失函数本身并不难理解,但是代码中有很多张量运算及相关函数的使用,使得稍显复杂。

其他相关的部分请见:
YOLO代码解析(1) 代码总览与使用
YOLO代码解析(2) 数据处理
YOLO代码解析(3) 模型和损失函数
YOLO代码解析(4) 训练和测试代码

YOLO论文中的网络结构示意图如下:
在这里插入图片描述
网络结构相关代码:yolo_tiny_net.py
这里的网络与YOLO论文中的网络结构稍有不同,不过整体上是一致的

def inference(self, images):
    """构建yolo_tiny网络
    输入:
      images:  4-D tensor [batch_size, image_height, image_width, channels]
    返回:
      predicts: 4-D tensor [batch_size, cell_size, cell_size, num_classes + 5 * boxes_per_cell]
    """
    conv_num = 1

    temp_conv = self.conv2d('conv' + str(conv_num), images, [3, 3, 3, 16], stride=1)
    conv_num += 1

    temp_pool = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_pool, [3, 3, 16, 32], stride=1)
    conv_num += 1

    temp_pool = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_pool, [3, 3, 32, 64], stride=1)
    conv_num += 1

    temp_conv = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 64, 128], stride=1)
    conv_num += 1

    temp_conv = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 128, 256], stride=1)
    conv_num += 1

    temp_conv = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 256, 512], stride=1)
    conv_num += 1

    temp_conv = self.max_pool(temp_conv, [2, 2], 2)

    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 512, 1024], stride=1)
    conv_num += 1

    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 1024, 1024], stride=1)
    conv_num += 1

    temp_conv = self.conv2d('conv' + str(conv_num), temp_conv, [3, 3, 1024, 1024], stride=1)
    conv_num += 1

    temp_conv = tf.transpose(temp_conv, (0, 3, 1, 2)) #(N,H,W,C)=>(N,C,H,W)

    # 全链接层
    local1 = self.local('local1', temp_conv, self.cell_size * self.cell_size * 1024, 256)
    local2 = self.local('local2', local1, 256, 4096)
    local3 = self.local('local3', local2, 4096, self.cell_size * self.cell_size * (self.num_classes + self.boxes_per_cell * 5), leaky=False, pretrain=False, train=True)

    # 对全连接层输出的tensor进行reshape
    # 全连接输出的长度cell_size*cell_size*(num_class+boxes_per_cell*5)二维tensor(还有一个维度是图片数目N)
    # YOLO论文中的7*7*(20+5*2)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值