FCN—tensorflow版本代码超详解

本文介绍了使用TensorFlow实现FCN的详细过程,包括FCN.py中vggnet函数、inference函数及main主函数的解析,并提供了BatchDatestreader.py和read_MITSceneParsingData.py的数据读取方法,旨在帮助读者深入理解FCN的实现并掌握大厂面试技巧。
摘要由CSDN通过智能技术生成

代码共有四个文件,分别如下:
FCN.py
vggnet函数:

# 根据载入的权重建立原始的 VGGNet 的网络
def vgg_net(weights, image):
layers = (
    'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
    'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
    'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
    'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
    'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4'
)
net = {}
current = image
for i, name in enumerate(layers):
    kind = name[:4]
    if kind == 'conv':
        kernels, bias = weights[i][0][0][0][0]
        # matconvnet: weights are [width, height, in_channels, out_channels]
        # tensorflow: weights are [height, width, in_channels, out_channels]
        kernels = utils.get_variable(np.transpose(kernels, (1, 0, 2, 3)), name=name + "_w")  # k值转置
        bias = utils.get_variable(bias.reshape(-1), name=name + "_b")  # b转成一维
        current = utils.conv2d_basic(current, kernels, bias)  # 图像卷积完加b
        print("当前形状:", np.shape(current))
    elif kind == 'relu':
        current = tf.nn.relu(current, name=name)
        if debug:
            utils.add_activation_summary(current)
    elif kind == 'pool':
        current = utils.avg_pool_2x2(current) # 平均池化
        print("当前形状:", np.shape(current))
    net[name] = current
return net

inference函数:

 # FCN的网络结构定义,网络中用到的参数是迁移VGG训练好的参数
def inference(image, keep_prob):  # 输入图像和dropout值
    """
    Semantic segmentation network definition
    :param image: input image. Should have values in range 0-255
    :param keep_prob:
    :return:
    """
    # 加载模型数据,获得标准化均值
    print("原始图像:", np.shape(image))
    model_data = utils.get_model_data(model_path)
    mean = model_data['normalization'][0][0][0]  # 通过字典获取mean值,vgg模型参数里有normaliza这个字典,三个0用来去虚维找到mean
    mean_pixel = np.mean(mean, axis=(0, 1))
    weights = np.squeeze(model_data['layers'])  # 从数组的形状中删除单维度条目,获得vgg权重

    # 图像预处理
    processed_image = utils.process_image(image, mean_pixel)  # 图像减平均值实现标准化
    print("预处理后的图像:", np.shape(processed_image))

    with tf.variable_scope("inference"):
        # 建立原始的VGGNet-19网络

        print("开始建立VGG网络:")
        image_net = vgg_net(weights, processed_image)

        # 在VGGNet-19之后添加 一个池化层和三个卷积层
        conv_final_layer = image_net["conv5_3"]  # 14*14*512
        print("VGG处理后的图像:", np.shape(conv_final_layer))

        pool5 = utils.max_pool_2x2(conv_final_layer)  # w,h/32 =7*7*512

        print("pool5:", np.shape(pool5))

        W6 = utils.weight_variable([7, 7, 512, 4096], name="W6")
        b6 = utils.bias_variable([4096], name="b6")
        conv6 = utils.conv2d_basic(pool5, W6, b6)  # 1*1*4096
        relu6 = tf.nn.relu(conv6, name="relu6")
        if debug:
            utils.add_activation_summary(relu6)
        relu_dropout6 = tf.nn.dropout(relu6, keep_prob=keep_prob)

        print("conv6:", np.shape(relu_dropout6))

        W7 = utils.weight_variable([1, 1, 4096, 4096], name="W7")
        b7 = utils.bias_variable([4096], name="b7")
        conv7 = utils.conv2d_basic(relu_dropout6, W7,
  • 4
    点赞
  • 72
    收藏
    觉得还不错? 一键收藏
  • 51
    评论
评论 51
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值