conv2d_transpose()测试

#coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image


mask_path = './tmp/masks/PAAPhoto_70S3A970B3FAV_PAAPhoto20191021184516_1001_-1_-1_2448_3264.png'
# Load the ground-truth mask as a batch of one with a trailing channel axis:
# a 2-D (H, W) mask becomes (1, H, W, 1), then is cast to int32 labels.
# NOTE(review): sparse softmax cross-entropy needs label values in
# [0, num_classes) -- 0/1 for the 2-channel logits built below; confirm the PNG
# does not store e.g. 255 for foreground. A multi-channel (RGB/RGBA) PNG would
# also produce a 5-D array here -- confirm the mask is single-channel.
with Image.open(mask_path) as img_data:  # close the file handle deterministically
    annotation = np.array(img_data)
annotation = np.expand_dims(annotation, axis=0)
annotation = np.expand_dims(annotation, axis=3)
annotation = tf.cast(annotation, tf.int32)
print(annotation.shape)

def weight_variable(shape, stddev=0.02, name=None):
    """Create a weight variable initialised from a truncated normal.

    Args:
        shape: Shape of the weight tensor.
        stddev: Standard deviation of the truncated-normal initial values.
        name: Optional variable name. When given, the variable is created
            through ``tf.get_variable``; otherwise through ``tf.Variable``.

    Returns:
        The newly created variable.
    """
    init_value = tf.truncated_normal(shape, stddev=stddev)
    if name is not None:
        return tf.get_variable(name, initializer=init_value)
    return tf.Variable(init_value)

def bias_variable(shape, name=None):
    """Create a bias variable whose initial values are all zero.

    Args:
        shape: Shape of the bias tensor.
        name: Optional variable name. When given, the variable is created
            through ``tf.get_variable``; otherwise through ``tf.Variable``.

    Returns:
        The newly created variable.
    """
    zeros = tf.constant(0.0, shape=shape)
    if name is not None:
        return tf.get_variable(name, initializer=zeros)
    return tf.Variable(zeros)

def conv2d_transpose_strided(x, W, b, output_shape=None, stride=2):
    """Upsample ``x`` with a strided transposed convolution, then add a bias.

    Args:
        x: Input tensor in NHWC layout (strides are passed as
            ``[1, stride, stride, 1]``).
        W: Filter of shape ``[height, width, out_channels, in_channels]``, the
            layout ``tf.nn.conv2d_transpose`` expects; ``W``'s third dimension
            is used as the output channel count.
        b: Bias vector added to the result; its length must equal
            ``out_channels``.
        output_shape: Shape of the result (list or 1-D tensor). When None it is
            derived from the static shape of ``x``: height and width are
            multiplied by ``stride`` and the channel count is taken from ``W``.
            NOTE(review): that derivation assumes ``x`` has a fully defined
            static shape (including batch size).
        stride: Spatial stride applied to both height and width.

    Returns:
        The transposed-convolution output with ``b`` added.
    """
    if output_shape is None:
        output_shape = x.get_shape().as_list()
        # Scale by the actual stride; a hard-coded factor of 2 produced a
        # mismatched output_shape for any stride other than 2.
        output_shape[1] *= stride
        output_shape[2] *= stride
        output_shape[3] = W.get_shape().as_list()[2]
    conv = tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, stride, stride, 1], padding="SAME")
    return tf.nn.bias_add(conv, b)

# Decoder: each stage upsamples the deeper feature map with a stride-2
# transposed convolution to the shape of a shallower layer, adds that layer
# as a skip connection, then applies batch normalization and ReLU.
# NOTE(review): layer_*/shape_* come from earlier code not shown here; shape_N
# presumably holds layer_N.get_shape() -- confirm.
# NOTE(review): tf.layers.batch_normalization is called without `training=`,
# so it runs with its default (training=False, moving statistics). Pass
# training=True when building a training graph.
# Transposed-conv kernels are [h, w, out_channels, in_channels], so each bias
# has the skip layer's channel count.

# Stage 19 -> 14.
kernel_19 = weight_variable(shape=[3, 3, shape_14[3].value, shape_19[3].value], name="kernel_19")
bais_19 = bias_variable(shape=[shape_14[3].value], name="bais_19")
up_19 = conv2d_transpose_strided(layer_19, kernel_19, bais_19, output_shape=tf.shape(layer_14))
add_19_14 = tf.add(up_19, layer_14, name="add_19_14")
bn_19_14 = tf.layers.batch_normalization(add_19_14, name="bn_19_14")
out_19_14 = tf.nn.relu(bn_19_14)

# Stage 14 -> 7. Graph variable names follow the kernel_<N>/bais_<N> pattern
# of the other stages (previously "kernel14" and the copy-paste "bais_17").
kernel_14 = weight_variable(shape=[3, 3, shape_7[3].value, shape_14[3].value], name="kernel_14")
bais_14 = bias_variable(shape=[shape_7[3].value], name="bais_14")
up_14 = conv2d_transpose_strided(out_19_14, kernel_14, bais_14, output_shape=tf.shape(layer_7))
add_14_7 = tf.add(up_14, layer_7, name="add_14_7")
bn_14_7 = tf.layers.batch_normalization(add_14_7, name="bn_14_7")
out_14_7 = tf.nn.relu(bn_14_7)

# Stage 7 -> 4.
kernel_7 = weight_variable(shape=[3, 3, shape_4[3].value, shape_7[3].value], name="kernel_7")
bais_7 = bias_variable(shape=[shape_4[3].value], name="bais_7")
up_7 = conv2d_transpose_strided(out_14_7, kernel_7, bais_7, output_shape=tf.shape(layer_4))
add_7_4 = tf.add(up_7, layer_4, name="add_7_4")
bn_7_4 = tf.layers.batch_normalization(add_7_4, name="bn_7_4")
out_7_4 = tf.nn.relu(bn_7_4)

# Stage 4 -> 2.
kernel_4 = weight_variable(shape=[3, 3, shape_2[3].value, shape_4[3].value], name="kernel_4")
bais_4 = bias_variable(shape=[shape_2[3].value], name="bais_4")
up_4 = conv2d_transpose_strided(out_7_4, kernel_4, bais_4, output_shape=tf.shape(layer_2))
add_4_2 = tf.add(up_4, layer_2, name="add_4_2")
bn_4_2 = tf.layers.batch_normalization(add_4_2, name="bn_4_2")
out_4_2 = tf.nn.relu(bn_4_2)

# Final stage: project to 2 channels at twice layer_2's spatial size (static
# shape_2 dims, so they must be fully known); these are the per-pixel logits.
kernel_2 = weight_variable(shape=[3, 3, 2, shape_2[3].value], name="kernel_2")
bais_2 = bias_variable(shape=[2], name="bais_2")
up_2 = conv2d_transpose_strided(out_4_2, kernel_2, bais_2, output_shape=[shape_2[0].value, shape_2[1].value * 2, shape_2[2].value * 2, 2])
# Per-pixel channel index of the largest logit; `axis` replaces the
# deprecated `dimension` keyword.
masks = tf.argmax(up_2, axis=3, name="mask")

global_step = tf.train.get_or_create_global_step()
# NOTE(review): UPDATE_OPS is only populated by batch-norm layers built with
# training=True; the decoder's batch_normalization calls use the default, so
# this is presumably empty here -- confirm before wiring it into a train op.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

# Mean per-pixel cross-entropy between the 2-channel logits `up_2` and the
# mask labels. The squeeze drops the trailing singleton channel axis of
# `annotation` ([1, H, W, 1] -> [1, H, W], assuming a 2-D mask); `axis`
# replaces the deprecated `squeeze_dims` keyword.
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=up_2,
                                                                  labels=tf.squeeze(annotation, axis=[3]),
                                                                  name="entropy"))

def _fetch_and_report(session, tensor, template):
    """Evaluate ``tensor`` in ``session``, print its shape via ``template``, return the value."""
    value = session.run(tensor)
    print(template.format(value.shape))
    return value


# Evaluate each decoder stage one at a time so that a shape mismatch surfaces
# at the first failing stage, after the earlier shapes have been printed.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print("layer_19.shape: {}".format(layer_19.get_shape()))
    up_19_value = _fetch_and_report(sess, up_19, "up_19.shape: {}")
    add_19_14_value = _fetch_and_report(sess, add_19_14, "add_19_14_value.shape: {}")
    out_19_14_value = _fetch_and_report(sess, out_19_14, "out_19_14_value.shape: {}")

    up_14_value = _fetch_and_report(sess, up_14, "up_14.shape: {}")
    add_14_7_value = _fetch_and_report(sess, add_14_7, "add_14_7_value.shape: {}")
    out_14_7_value = _fetch_and_report(sess, out_14_7, "out_14_7_value.shape: {}")

    up_7_value = _fetch_and_report(sess, up_7, "up_7.shape: {}")
    add_7_4_value = _fetch_and_report(sess, add_7_4, "add_7_4_value.shape: {}")
    out_7_4_value = _fetch_and_report(sess, out_7_4, "out_7_4_value.shape: {}")

    up_4_value = _fetch_and_report(sess, up_4, "up_4.shape: {}")
    add_4_2_value = _fetch_and_report(sess, add_4_2, "add_4_2_value.shape: {}")
    out_4_2_value = _fetch_and_report(sess, out_4_2, "out_4_2_value.shape: {}")

    up_2_value = _fetch_and_report(sess, up_2, "up_2.shape: {}")

    mask_value = _fetch_and_report(sess, masks, "mask_value.shape: {}")
    # Sample one row of the predicted mask (requires height > 50).
    print(mask_value[0, 50, :])

    loss_value = sess.run(loss)
    print("loss_value: {}".format(loss_value))

Reference: https://github.com/ferryer/FCN-tensorflow-hzp

发布了15 篇原创文章 · 获赞 4 · 访问量 1万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 大白 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览