unet实现区域分割

https://github.com/zonghaofan/pig-seg/tree/master/disk_segmentation

网络架构:

# coding:utf-8
import tensorflow as tf
import cv2
import numpy as np
import matplotlib.pyplot as plt

img = cv2.imread('./data/test.png')

# cv2.imshow('1.jpg',img)
# cv2.waitKey(0)
img = cv2.resize(img, (1024, 1024))
img = np.array(img).astype(np.float32)
img = img[np.newaxis, ...]
print(img.shape)

x_input = tf.placeholder(shape=[None, 1024, 1024, 3], dtype=tf.float32)

# x=tf.random_normal(shape=[1,1024,1024,3],dtype=tf.float32)
n_filters = [8, 8]


# name=1
def conv2d(x, n_filters, training, name, pool=True, activation=tf.nn.relu):
    with tf.variable_scope('layer{}'.format(name)):
        for index, filter in enumerate(n_filters):
            conv = tf.layers.conv2d(x, filter, (3, 3), strides=1, padding='same', activation=None,
                                    name='conv_{}'.format(index + 1))
            conv = tf.layers.batch_normalization(conv, training=training, name='bn_{}'.format(index + 1))
            conv = activation(conv, name='relu{}_{}'.format(name, index + 1))

        if pool is False:
            return conv

        pool = tf.layers.max_pooling2d(conv, pool_size=(2, 2), strides=2, name='pool_{}'.format(name))

        return conv, pool


def upsampling_2d(tensor, name, size=(2, 2)):
    h_, w_, c_ = tensor.get_shape().as_list()[1:]
    h_multi, w_multi = size
    h = h_multi * h_
    w = w_multi * w_
    target = tf.image.resize_nearest_neighbor(tensor, size=(h, w), name='upsample_{}'.format(name))

    return target


def upsampling_concat(input_A, input_B, name):
    upsampling = upsampling_2d(input_A, name=name, size=(2, 2))
    up_concat = tf.concat([upsampling, input_B], axis=-1, name='up_concat_{}'.format(name))

    return up_concat

def unet(input):
    #归一化 -1~1
    input=(input-127.5)/127.5
    conv1, pool1 = conv2d(input, [8, 8], training=True, name=1)
    print(conv1.shape)
    print(pool1.shape)
    conv2, pool2 = conv2d(pool1, [16, 16], training=True, name=2)
    print(conv2.shape)
    print(pool2.shape)
    conv3, pool3 = conv2d(pool2, [32, 32], training=True, name=3)
    print(conv3.shape)
    print(pool3.shape)
    conv4, pool4 = conv2d(pool3, [64, 64], training=True, name=4)
    print(conv4.shape)
    print(pool4.shape)
    conv5 = conv2d(pool4, [128, 128], training=True, pool=False, name=5)
    print(conv5.shape)

    up6 = upsampling_concat(conv5, conv4, name=6)
    print('up6', up6.shape)
    conv6 = conv2d(up6, [64, 64], training=True, pool=False, name=6)
    print(conv6.shape)
    up7 = upsampling_concat(conv6, conv3, name=7)
    print('up7', up7.shape)
    conv7 = conv2d(up7, [32, 32], training=True, pool=False, name=7)
    print(conv7.shape)
    up8 = upsampling_concat(conv7, conv2, name=8)
    print('up8', up8.shape)
    conv8 = conv2d(up8, [16, 16], training=True, pool=False, name=8)
    print(conv8.shape)
    up9 = upsampling_concat(conv8, conv1, name=9)
    print('up9', up9.shape)
    conv9 = conv2d(up9, [8, 8], training=True, pool=False, name=9)
    print(conv9.shape)
    final = tf.layers.conv2d(conv9, 1, (1, 1), name='final', activation=tf.nn.sigmoid, padding='same')
    print('final', final.shape)
    return final


if __name__ == '__main__':
    final=unet(x_input)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        y_final = sess.run(final, feed_dict={x_input: img})
    result = y_final[0, ...]
    print(result.shape)
    print(result[...,:10])


    # result=cv2.imread('./2.jpg')
    # result=cv2.resize(result,(640,640))
    # print(result)
    cv2.imshow('1.jpg', result)
    cv2.waitKey(0)

打印结果:这里打印值有小数,故直接imshow就是输出图,而如果imwrite,查看图片的值全是0,1,轮廓也能看清,只不过不是很清晰。

输入:

输出:截图没有完全

 

  • 3
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值