rpn

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Apr  6 15:08:11 2019

@author: fanzy
"""

import keras.layers as KL
import keras.backend as K
import tensorflow as tf
from keras.models import Model

def building_block(filters,block):
    if block!=0:
        stride=1
    else:
        stride=2
    def f(x):
        y=KL.Conv2D(filters,(1,1),strides=stride)(x)
        y=KL.BatchNormalization(axis=3)(y)#tensorflow 3rd dimension
        y=KL.Activation('relu')(y)
        
        y=KL.Conv2D(filters,(3,3),padding='same')(y)
        y=KL.BatchNormalization(axis=3)(y)
        y=KL.Activation('relu')(y)
        
        y=KL.Conv2D(4*filters,(1,1))(y)
        y=KL.BatchNormalization(axis=3)(y)
        
        if block==0:
            shortcut=KL.Conv2D(4*filters,(1,1),strides=stride)(x)
            shortcut=KL.BatchNormalization(axis=3)(shortcut)
        else:
            shortcut=x
        y=KL.Add()([y,shortcut])
        y=KL.Activation('relu')(y)
        return y
    return f
def resnet_featureExtractor(inputs):
    x=KL.Conv2D(64,(3,3),padding='same')(inputs)
    x=KL.BatchNormalization(axis=3)(x)
    x=KL.Activation('relu')(x)
    
    filters=64
    blocks=[3,6,4]
    for i ,block_num in enumerate(blocks):
        for block_id in range(block_num):
            x=building_block(filters,block_id)(x)#blockid=0,1,2...
        filters*=2
    return x


    
def rpn_net(inputs,k):#k=anchors num
    shared_map=KL.Conv2D(256,(3,3),padding='same')(inputs)
    shared_map=KL.Activation('linear')(shared_map)
    
    rpn_class=KL.Conv2D(2*k,(1,1))(shared_map)
    rpn_class=KL.Lambda(lambda x: tf.reshape(x,[tf.shape(x)[0],-1,2]))(rpn_class)
    rpn_class=KL.Activation('linear')(rpn_class)
    rpn_prob=KL.Activation('softmax')(rpn_class)
    
    xy=KL.Conv2D(4*k,(1,1))(shared_map)
    xy=KL.Activation('linear')(xy)
    #zidingyi reshape layer
    rpn_bbox=KL.Lambda(lambda x: tf.reshape(x,[tf.shape(x)[0],-1,4]))(xy)
    
    return rpn_class,rpn_prob,rpn_bbox


x=KL.Input((64,64,3))
y=resnet_featureExtractor(x)

rpn_class,rpn_prob,rpn_bbox=rpn_net(y,9)

model=Model([x],[rpn_class,rpn_prob,rpn_bbox])
model.summary()
        
       

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值