论文地址:https://arxiv.org/abs/2004.00626
代码:https://github.com/senguptaumd/Background-Matting
背景介绍
抠图是照片编辑和视觉效果中使用的标准技术。在现有的抠图算法中,要想抠出一个好的mask,一般需要三分图(trimap,由前景、背景和未知区域三部分组成)。虽然不需要三分图的算法也在发展,但这类算法的抠图质量目前还无法与使用三分图的算法相比。
因此,在本算法中除了需要原图片之外,还需要一张额外的背景图片。
抠图算法的公式
I = αF+(1−α)B
F:前景图(foreground), B:背景图(background), α:混合系数(mixing coefficient), I:合成得到的图像(上式即图像合成方程)。
当 α 趋近于0的时候,就会获得背景图;相反,当 α 趋近于1时,就会获得前景图。
方法介绍
核心方法
在本文中,核心是使用一个深度抠图网络G,从输入的图片中提取前景色和 α,并以背景图和软分割作为额外输入来增强效果,再接上一个鉴别器网络D来指导训练,从而生成真实的结果。
下面是做的代码展示的一些图像分割效果:
这是resnet50.py
"""ResNet50 model for Keras.
# Reference:
- [Deep Residual Learning for Image Recognition](
https://arxiv.org/abs/1512.03385) (CVPR 2016 Best Paper Award)
Adapted from code contributed by BigMoyan.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import warnings
from . import imagenet_utils
from .imagenet_utils import decode_predictions
from .imagenet_utils import _obtain_input_shape
import tensorflow as tf
# Re-export the preprocessing helper from `imagenet_utils` under this module.
preprocess_input = imagenet_utils.preprocess_input
# URL of the weights file `ResNet50` downloads when `weights='imagenet'`
# and `include_top` is True.
WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/'
'releases/download/v0.2/'
'resnet50_weights_tf_dim_ordering_tf_kernels.h5')
# URL of the weights file `ResNet50` downloads when `weights='imagenet'`
# and `include_top` is False.
WEIGHTS_PATH_NO_TOP = ('https://github.com/fchollet/deep-learning-models/'
'releases/download/v0.2/'
'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
# Short module-level aliases for the tf.keras submodules used below.
backend = tf.keras.backend
layers = tf.keras.layers
models = tf.keras.models
keras_utils = tf.keras.utils
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Residual block whose shortcut is the unmodified input tensor.

    The main path is 1x1 conv -> BN -> ReLU -> `kernel_size` conv -> BN ->
    ReLU -> 1x1 conv -> BN. Its output is added to `input_tensor` and passed
    through a final ReLU.

    # Arguments
        input_tensor: input tensor.
        kernel_size: kernel size of the middle convolution of the main path.
        filters: sequence of three integers, the filter counts of the three
            main-path convolutions, in order.
        stage: integer stage label, used to build layer names.
        block: block label ('a', 'b', ...), used to build layer names.

    # Returns
        Output tensor for the block.
    """
    n_reduce, n_middle, n_expand = filters
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

    # Layer names follow the pattern res<stage><block>_branch<suffix>.
    branch_tag = str(stage) + block + '_branch'
    conv_prefix = 'res' + branch_tag
    bn_prefix = 'bn' + branch_tag

    # 1x1 convolution.
    out = layers.Conv2D(n_reduce, (1, 1),
                        kernel_initializer='he_normal',
                        name=conv_prefix + '2a')(input_tensor)
    out = layers.BatchNormalization(axis=bn_axis, name=bn_prefix + '2a')(out)
    out = layers.Activation('relu')(out)

    # Middle convolution; 'same' padding keeps the spatial size.
    out = layers.Conv2D(n_middle, kernel_size,
                        padding='same',
                        kernel_initializer='he_normal',
                        name=conv_prefix + '2b')(out)
    out = layers.BatchNormalization(axis=bn_axis, name=bn_prefix + '2b')(out)
    out = layers.Activation('relu')(out)

    # Final 1x1 convolution; no ReLU until after the residual addition.
    out = layers.Conv2D(n_expand, (1, 1),
                        kernel_initializer='he_normal',
                        name=conv_prefix + '2c')(out)
    out = layers.BatchNormalization(axis=bn_axis, name=bn_prefix + '2c')(out)

    out = layers.add([out, input_tensor])
    return layers.Activation('relu')(out)
def conv_block(input_tensor,
               kernel_size,
               filters,
               stage,
               block,
               strides=(2, 2)):
    """Residual block whose shortcut goes through a 1x1 convolution.

    The main path is the same as in `identity_block`, except that the first
    1x1 convolution uses `strides`. The shortcut is a 1x1 convolution with
    the same `strides` followed by batch normalization, so both branches
    produce matching shapes before they are added.

    # Arguments
        input_tensor: input tensor.
        kernel_size: kernel size of the middle convolution of the main path.
        filters: sequence of three integers, the filter counts of the three
            main-path convolutions, in order. The last one is also used for
            the shortcut convolution.
        stage: integer stage label, used to build layer names.
        block: block label ('a', 'b', ...), used to build layer names.
        strides: strides of the first main-path convolution and of the
            shortcut convolution.

    # Returns
        Output tensor for the block.
    """
    n_reduce, n_middle, n_expand = filters
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

    # Layer names follow the pattern res<stage><block>_branch<suffix>.
    branch_tag = str(stage) + block + '_branch'
    conv_prefix = 'res' + branch_tag
    bn_prefix = 'bn' + branch_tag

    # Main path: strided 1x1 convolution.
    main = layers.Conv2D(n_reduce, (1, 1), strides=strides,
                         kernel_initializer='he_normal',
                         name=conv_prefix + '2a')(input_tensor)
    main = layers.BatchNormalization(axis=bn_axis,
                                     name=bn_prefix + '2a')(main)
    main = layers.Activation('relu')(main)

    # Main path: middle convolution.
    main = layers.Conv2D(n_middle, kernel_size, padding='same',
                         kernel_initializer='he_normal',
                         name=conv_prefix + '2b')(main)
    main = layers.BatchNormalization(axis=bn_axis,
                                     name=bn_prefix + '2b')(main)
    main = layers.Activation('relu')(main)

    # Main path: final 1x1 convolution.
    main = layers.Conv2D(n_expand, (1, 1),
                         kernel_initializer='he_normal',
                         name=conv_prefix + '2c')(main)
    main = layers.BatchNormalization(axis=bn_axis,
                                     name=bn_prefix + '2c')(main)

    # Shortcut: strided 1x1 convolution applied to the block input.
    shortcut = layers.Conv2D(n_expand, (1, 1), strides=strides,
                             kernel_initializer='he_normal',
                             name=conv_prefix + '1')(input_tensor)
    shortcut = layers.BatchNormalization(axis=bn_axis,
                                         name=bn_prefix + '1')(shortcut)

    merged = layers.add([main, shortcut])
    return layers.Activation('relu')(merged)
def ResNet50(include_top=True,
             weights='imagenet',
             input_tensor=None,
             input_shape=None,
             pooling=None,
             classes=1000,
             **kwargs):
    """Build the ResNet50 model, optionally loading pre-trained weights.

    The data format (`channels_last` / `channels_first`) is read from the
    Keras backend configuration.

    # Arguments
        include_top: whether to append global average pooling and the
            softmax classification layer.
        weights: `None` (random initialization), `'imagenet'` (download
            the pre-trained weights), or a path to a weights file.
        input_tensor: optional tensor to use as the model input. A
            non-Keras tensor is wrapped in an `Input` layer.
        input_shape: optional input shape tuple. It is validated by
            `_obtain_input_shape` with a default size of 224 and a minimum
            size of 32.
        pooling: only used when `include_top` is False. `'avg'` applies
            global average pooling and `'max'` global max pooling to the
            last convolutional output. Any other value leaves the 4D
            output untouched and emits a warning.
        classes: number of units of the classification layer. Must be
            1000 when `weights='imagenet'` and `include_top` is True.
        **kwargs: accepted and ignored.

    # Returns
        A Keras model instance named 'resnet50'.

    # Raises
        ValueError: if `weights` is neither `None`, `'imagenet'` nor an
            existing path, or if `classes` conflicts with the ImageNet
            weights.
    """
    weights_is_valid = weights in {'imagenet', None} or os.path.exists(weights)
    if not weights_is_valid:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')

    # Determine proper input shape.
    input_shape = _obtain_input_shape(
        input_shape,
        default_size=224,
        min_size=32,
        data_format=tf.keras.backend.image_data_format(),
        require_flatten=include_top,
        weights=weights)

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    elif backend.is_keras_tensor(input_tensor):
        img_input = input_tensor
    else:
        img_input = layers.Input(tensor=input_tensor, shape=input_shape)

    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

    # Stem: padded 7x7 stride-2 convolution, then padded 3x3 stride-2 pooling.
    x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
    x = layers.Conv2D(64, (7, 7),
                      strides=(2, 2),
                      padding='valid',
                      kernel_initializer='he_normal',
                      name='conv1')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = layers.Activation('relu')(x)
    x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

    # (stage, filters, block labels, strides of the leading conv_block).
    # Blocks are created strictly in this order, so auto-generated layer
    # names come out in the same sequence as a hand-unrolled version.
    stage_specs = (
        (2, [64, 64, 256], 'abc', (1, 1)),
        (3, [128, 128, 512], 'abcd', (2, 2)),
        (4, [256, 256, 1024], 'abcdef', (2, 2)),
        (5, [512, 512, 2048], 'abc', (2, 2)),
    )
    for stage, stage_filters, block_labels, first_strides in stage_specs:
        x = conv_block(x, 3, stage_filters, stage=stage,
                       block=block_labels[0], strides=first_strides)
        for label in block_labels[1:]:
            x = identity_block(x, 3, stage_filters, stage=stage, block=label)

    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
    elif pooling == 'avg':
        x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D()(x)
    else:
        warnings.warn('The output shape of `ResNet50(include_top=False)` '
                      'has been changed since Keras 2.2.0.')

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is None:
        inputs = img_input
    else:
        inputs = keras_utils.get_source_inputs(input_tensor)

    # Create model.
    model = models.Model(inputs, x, name='resnet50')

    # Load weights.
    if weights == 'imagenet':
        if include_top:
            file_name = 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'
            origin = WEIGHTS_PATH
            md5_hash = 'a7b3fe01876f51b976af0dea6bc144eb'
        else:
            file_name = 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
            origin = WEIGHTS_PATH_NO_TOP
            md5_hash = 'a268eb855778b3df3c7506639542a6af'
        weights_path = keras_utils.get_file(
            file_name,
            origin,
            cache_subdir='models',
            md5_hash=md5_hash)
        model.load_weights(weights_path)
        if backend.backend() == 'theano':
            keras_utils.convert_all_kernels_in_model(model)
    elif weights is not None:
        model.load_weights(weights)

    return model
这是用deeplab:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import AveragePooling2D, Lambda, Conv2D, Conv2DTranspose, Activation, Reshape, concatenate, Concatenate, BatchNormalization, ZeroPadding2D
from resnet50 import ResNet50
def Upsample(tensor, size):
    """Resize `tensor` to `size` with `tf.image.resize` inside a Lambda layer.

    # Arguments
        tensor: input feature map; its last axis is treated as the channel
            axis (assumes channels_last layout - TODO confirm for callers).
        size: two-element sequence `[new_height, new_width]`.

    # Returns
        The resized tensor. The Lambda layer is named after the first
        component of `tensor.name`, with '_upsample' appended.
    """
    name = tensor.name.split('/')[0] + '_upsample'
    # `output_shape` describes one sample without the batch axis, so it
    # has to include the channel axis. Passing only [height, width] would
    # declare a rank-2 sample whenever Keras consults the declared shape.
    channels = tensor.shape[-1]
    sample_shape = (size[0], size[1], channels)

    def bilinear_upsample(x, size):
        # tf.image.resize uses its default interpolation method here.
        return tf.image.resize(images=x, size=size)

    return Lambda(lambda x: bilinear_upsample(x, size),
                  output_shape=sample_shape, name=name)(tensor)
def ASPP(tensor):
    """Atrous spatial pyramid pooling head.

    Builds five parallel branches on `tensor`: an image-level pooling
    branch (average pool over the whole feature map, 1x1 conv, BN, ReLU,
    upsample back), a 1x1 conv branch, and three 3x3 conv branches with
    dilation rates 6, 12 and 18. The branches are concatenated and fused
    by a final 1x1 conv + BN + ReLU. Every convolution has 256 filters.

    # Arguments
        tensor: input feature map with static height and width.

    # Returns
        The fused feature map.
    """
    static_shape = K.int_shape(tensor)
    height, width = static_shape[1], static_shape[2]

    def conv_bn_relu(inputs, conv_name, tag, kernel_size, rate, pin_shape):
        """Conv2D (no bias) -> BatchNormalization -> ReLU."""
        out = Conv2D(filters=256, kernel_size=kernel_size,
                     dilation_rate=rate, padding='same',
                     kernel_initializer='he_normal', name=conv_name,
                     use_bias=False)(inputs)
        if pin_shape:
            # Re-assert the static shape after the dilated convolution.
            out.set_shape([None, height, width, 256])
        out = BatchNormalization(name='bn_' + tag)(out)
        return Activation('relu', name='relu_' + tag)(out)

    # Image-level branch: pool the whole map to 1x1, project, resize back.
    pooled = AveragePooling2D(pool_size=(height, width),
                              name='average_pooling')(tensor)
    pooled = Conv2D(filters=256, kernel_size=1, padding='same',
                    kernel_initializer='he_normal', name='pool_1x1conv2d',
                    use_bias=False)(pooled)
    pooled = BatchNormalization(name='bn_1')(pooled)
    pooled = Activation('relu', name='relu_1')(pooled)
    pooled = Upsample(tensor=pooled, size=[height, width])

    # 1x1 branch, then the three dilated 3x3 branches, created in order.
    branches = [pooled,
                conv_bn_relu(tensor, 'ASPP_conv2d_d1', '2', 1, 1, False)]
    for tag, rate in (('3', 6), ('4', 12), ('5', 18)):
        branches.append(
            conv_bn_relu(tensor, f'ASPP_conv2d_d{rate}', tag, 3, rate, True))

    fused = concatenate(branches, name='ASPP_concat')
    return conv_bn_relu(fused, 'ASPP_conv2d_final', 'final', 1, 1, False)
def DeepLabV3Plus(img_height, img_width):
    """Build a DeepLabV3+ style binary segmentation model on ResNet50.

    The encoder is `ResNet50` with ImageNet weights and no classification
    head. The input of stage 5 goes through `ASPP` and is upsampled to a
    quarter of the input resolution, where it is concatenated with a
    48-channel projection of the input of stage 3. Two 3x3 convolutions
    refine the result, which is then upsampled to the full resolution and
    mapped to one sigmoid channel.

    # Arguments
        img_height: input image height in pixels.
        img_width: input image width in pixels.

    # Returns
        A Keras model named 'DeepLabV3_Plus' whose input is the backbone
        input and whose output is the sigmoid map.
    """
    print('*** Building DeepLabv3Plus Network ***')
    base_model = ResNet50(input_shape=(img_height, img_width, 3),
                          weights='imagenet', include_top=False)

    # The backbone's Activation layers are unnamed, so names such as
    # 'activation_39' come from a process-wide counter and only exist when
    # this is the first model built. Anchor on explicitly named convolution
    # layers instead: the input of the first convolution of a stage is the
    # output of the previous stage. With fresh TF2 auto-naming these should
    # be the tensors 'activation_39' and 'activation_9' refer to.
    image_features = base_model.get_layer('res5a_branch2a').input
    x_a = ASPP(image_features)
    x_a = Upsample(tensor=x_a, size=[img_height // 4, img_width // 4])

    # Low-level features, projected to 48 channels.
    x_b = base_model.get_layer('res3a_branch2a').input
    x_b = Conv2D(filters=48, kernel_size=1, padding='same',
                 kernel_initializer='he_normal',
                 name='low_level_projection', use_bias=False)(x_b)
    x_b = BatchNormalization(name='bn_low_level_projection')(x_b)
    x_b = Activation('relu', name='low_level_activation')(x_b)

    x = concatenate([x_a, x_b], name='decoder_concat')

    # NOTE(review): the decoder convolutions apply ReLU themselves and are
    # followed by BN and another ReLU. This is kept as-is so the network
    # computes exactly what existing checkpoints were trained with.
    x = Conv2D(filters=256, kernel_size=3, padding='same', activation='relu',
               kernel_initializer='he_normal',
               name='decoder_conv2d_1', use_bias=False)(x)
    x = BatchNormalization(name='bn_decoder_1')(x)
    x = Activation('relu', name='activation_decoder_1')(x)

    x = Conv2D(filters=256, kernel_size=3, padding='same', activation='relu',
               kernel_initializer='he_normal',
               name='decoder_conv2d_2', use_bias=False)(x)
    x = BatchNormalization(name='bn_decoder_2')(x)
    x = Activation('relu', name='activation_decoder_2')(x)

    # Back to the input resolution, then a 1x1 convolution to one channel.
    x = Upsample(x, [img_height, img_width])
    x = Conv2D(filters=1, kernel_size=1, strides=1, name='output_layer')(x)
    x = Activation('sigmoid')(x)

    model = Model(inputs=base_model.input, outputs=x, name='DeepLabV3_Plus')
    print(f'*** Output_Shape => {model.output_shape} ***')
    return model
下面是train.py
import os
from glob import glob

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint

from deeplab_test import DeepLabV3Plus
# Report the TensorFlow version in use.
print('TensorFlow', tf.__version__)

# Configuration: network input height / width and batch size.
H, W = 784, 784
batch_size = 12

# Collect the data. Sorting makes images and masks line up by file name.
train_images = sorted(glob('resized_images/*'))
train_masks = sorted(glob('resized_masks/*'))
val_images = sorted(glob('validation_data/images/*'))
val_masks = sorted(glob('validation_data/masks/*'))
print(f'Found {len(train_images)} training images')
print(f'Found {len(train_masks)} training masks')
print(f'Found {len(val_images)} validation images')
print(f'Found {len(val_masks)} validation masks')

# Make sure every image is paired with the mask of the same name. The
# check raises explicitly (asserts vanish under `python -O`), rejects lists
# of different length, and uses os.path.basename so it does not depend on
# '/' being the path separator.
for split_name, image_paths, mask_paths in (
        ('training', train_images, train_masks),
        ('validation', val_images, val_masks)):
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f'{split_name}: {len(image_paths)} images but '
            f'{len(mask_paths)} masks')
    for image_path, mask_path in zip(image_paths, mask_paths):
        # The name up to the first '.' must match between image and mask.
        image_key = os.path.basename(image_path).split('.')[0]
        mask_key = os.path.basename(mask_path).split('.')[0]
        if image_key != mask_key:
            raise ValueError(
                f'{split_name}: image {image_path!r} does not match '
                f'mask {mask_path!r}')
# Image preprocessing / augmentation helpers.
def random_scale(image, mask, min_scale=0.65, max_scale=2.5):
    """Resize `image` and `mask` by one random factor.

    A single factor is drawn uniformly from `[min_scale, max_scale)` and
    applied to both height and width. The image is resized with bilinear
    interpolation, the mask with nearest-neighbour interpolation.

    # Arguments
        image: image tensor whose first two axes are height and width.
        mask: mask tensor, resized to the same target size.
        min_scale: lower bound of the scale factor.
        max_scale: upper bound of the scale factor.

    # Returns
        Tuple `(scaled_image, scaled_mask)`.
    """
    factor = tf.random.uniform(shape=[1],
                               minval=min_scale,
                               maxval=max_scale)
    height_width = tf.cast(tf.shape(image), dtype=tf.float32)[:2]
    target_size = tf.cast(factor * height_width, dtype=tf.int32)
    scaled_image = tf.image.resize(image, size=target_size, method='bilinear')
    scaled_mask = tf.image.resize(mask, size=target_size, method='nearest')
    return scaled_image, scaled_mask
# Pad the inputs so that a crop of the requested size always fits.
def pad_inputs(image,
               mask,
               crop_height=H,
               crop_width=W,
               ignore_value=255,
               pad_value=0):
    """Pad `image` and `mask` at the bottom/right up to crop size plus one.

    Each spatial axis is padded to at least `crop + 1` pixels; axes that
    are already larger are left unchanged.

    # Arguments
        image: rank-3 image tensor (height, width, channels).
        mask: rank-3 mask tensor, padded by the same amounts.
        crop_height: crop height the padded result must accommodate.
        crop_width: crop width the padded result must accommodate.
        ignore_value: fill value for the padded mask area.
        pad_value: fill value for the padded image area.

    # Returns
        Tuple `(padded_image, padded_mask)`.
    """
    # Keep the shape as integers: tf.pad takes integer paddings, so there
    # is no need for a float round-trip and an implicit cast back.
    dims = tf.shape(image)
    h_pad = tf.maximum(1 + crop_height - dims[0], 0)
    w_pad = tf.maximum(1 + crop_width - dims[1], 0)
    paddings = [[0, h_pad], [0, w_pad], [0, 0]]
    padded_image = tf.pad(image, paddings=paddings,
                          constant_values=pad_value)
    padded_mask = tf.pad(mask, paddings=paddings, mode='CONSTANT',
                         constant_values=ignore_value)
    return padded_image, padded_mask
def random_crop(image, mask, crop_height=H, crop_width=W):
    """Cut the same random `crop_height` x `crop_width` window from both.

    The top-left corner is drawn uniformly so that the window lies inside
    `image`. Both spatial sizes of `image` must be strictly larger than
    the crop size, otherwise the offset range is empty.

    # Arguments
        image: rank-3 image tensor (height, width, channels).
        mask: rank-3 mask tensor with the same spatial size as `image`.
        crop_height: height of the crop window.
        crop_width: width of the crop window.

    # Returns
        Tuple `(image_crop, mask_crop)`.
    """
    image_dims = tf.shape(image)
    offset_h = tf.random.uniform(
        shape=(1,), maxval=image_dims[0] - crop_height, dtype=tf.int32)[0]
    # The width offset and target width use `crop_width`, so that
    # non-square crops come out with the requested width.
    offset_w = tf.random.uniform(
        shape=(1,), maxval=image_dims[1] - crop_width, dtype=tf.int32)[0]
    image = tf.image.crop_to_bounding_box(image,
                                          offset_height=offset_h,
                                          offset_width=offset_w,
                                          target_height=crop_height,
                                          target_width=crop_width)
    mask = tf.image.crop_to_bounding_box(mask,
                                         offset_height=offset_h,
                                         offset_width=offset_w,
                                         target_height=crop_height,
                                         target_width=crop_width)
    return image, mask
def random_flip(image, mask):
    """Flip `image` and `mask` left-right together, or leave both as-is.

    One integer is drawn from {0, 1}; both tensors are flipped when it is
    greater than zero, so the pair always stays aligned.

    # Arguments
        image: image tensor.
        mask: mask tensor.

    # Returns
        Tuple `(image, mask)` after the optional flip.
    """
    coin = tf.random.uniform(
        shape=[1, ], minval=0, maxval=2, dtype=tf.int32)[0]
    should_flip = tf.greater(coin, 0)

    def flipped_or_same(tensor):
        # Single-branch tf.case: flip when the coin says so, else identity.
        return tf.case([
            (should_flip, lambda: tf.image.flip_left_right(tensor))
        ], default=lambda: tensor)

    return flipped_or_same(image), flipped_or_same(mask)
# Load an image (or a mask) from disk.
def load_image(image_path, mask=False):
    """Read and decode the file at `image_path`.

    # Arguments
        image_path: path of the file to read.
        mask: when truthy the file is decoded with 1 channel, otherwise
            with 3 channels.

    # Returns
        The decoded tensor with static shape `[None, None, channels]`.
    """
    channels = 1 if mask else 3
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.image.decode_image(raw_bytes, channels=channels)
    # Record the rank and channel count on the decoded tensor.
    decoded.set_shape([None, None, channels])
    return decoded
@tf.function()
def preprocess_inputs(image_path, mask_path):
    """Load one image/mask pair and run the augmentation pipeline on it.

    The mask is binarized (`> 0` becomes 1, stored as uint8), then both
    tensors go through `random_scale`, `pad_inputs`, `random_crop` and
    `random_flip`. Finally the image channels are reversed and the
    per-channel constants `[103.939, 116.779, 123.68]` are subtracted.
    All ops are placed on the CPU.

    # Arguments
        image_path: path of the image file.
        mask_path: path of the mask file.

    # Returns
        Tuple `(image, mask)`.
    """
    with tf.device('/cpu:0'):
        image = load_image(image_path)
        mask = tf.cast(load_image(mask_path, mask=True) > 0, dtype=tf.uint8)

        # Each step takes and returns the (image, mask) pair.
        for augment in (random_scale, pad_inputs, random_crop, random_flip):
            image, mask = augment(image, mask)

        channel_means = tf.constant([103.939, 116.779, 123.68])
        reversed_channels = image[:, :, ::-1]
        image = reversed_channels - channel_means
        return image, mask
# Build the input pipelines: shuffle the path pairs, preprocess them in
# parallel, batch (dropping the last partial batch), repeat forever and
# prefetch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_masks))
    .shuffle(1024)
    .map(map_func=preprocess_inputs,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size=batch_size, drop_remainder=True)
    .repeat()
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# The validation pipeline goes through the same `preprocess_inputs`
# function; it only uses a smaller shuffle buffer.
val_dataset = (
    tf.data.Dataset.from_tensor_slices((val_images, val_masks))
    .shuffle(512)
    .map(map_func=preprocess_inputs,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size=batch_size, drop_remainder=True)
    .repeat()
    .prefetch(tf.data.experimental.AUTOTUNE)
)

print(train_dataset)
@tf.function()
def dice_coef(y_true, y_pred):
    """Dice coefficient over all pixels whose label is not 255.

    Pixels labelled 255 are dropped from both tensors. Predictions are
    thresholded at 0.5 before the overlap is computed.

    # Arguments
        y_true: ground-truth labels; 255 marks pixels to ignore.
        y_pred: predicted values, same shape as `y_true`.

    # Returns
        Scalar `2 * |A & B| / (|A| + |B|)`. When both the ground truth and
        the thresholded prediction are empty the score is 1.0 instead of
        NaN, so one empty batch cannot poison the epoch average.
    """
    valid = tf.not_equal(y_true, 255)
    y_true = tf.boolean_mask(y_true, valid)
    y_pred = tf.boolean_mask(y_pred, valid)

    # Cast the labels explicitly so the product below also works when the
    # labels arrive as integers.
    y_true_f = K.cast(K.flatten(y_true), 'float32')
    y_pred = K.cast(y_pred, 'float32')
    y_pred_f = K.cast(K.greater(K.flatten(y_pred), 0.5), 'float32')

    intersection = K.sum(y_true_f * y_pred_f)
    denominator = K.sum(y_true_f) + K.sum(y_pred_f)
    # divide_no_nan keeps the unselected branch finite when the
    # denominator is zero.
    score = tf.where(denominator > 0.,
                     tf.math.divide_no_nan(2. * intersection, denominator),
                     tf.ones_like(denominator))
    return score
@tf.function()
def loss(y_true, y_pred):
    """Binary cross-entropy over all pixels whose label is not 255.

    # Arguments
        y_true: ground-truth labels; 255 marks pixels to ignore.
        y_pred: predicted values, same shape as `y_true`.

    # Returns
        Result of `tf.losses.binary_crossentropy` on the kept pixels.
    """
    keep = tf.logical_not(tf.equal(y_true, 255))
    kept_labels = tf.boolean_mask(y_true, keep)
    kept_predictions = tf.boolean_mask(y_pred, keep)
    return tf.losses.binary_crossentropy(kept_labels, kept_predictions)
# Build and compile the model under a MirroredStrategy scope.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = DeepLabV3Plus(H, W)
    # TODO: Regularization loss model.add_loss(regularizer(model.layers[i].kernel))
    model.compile(loss=loss,
                  optimizer=tf.keras.optimizers.Adam(2e-5),
                  metrics=['accuracy', dice_coef])

# TensorBoard logging after every batch.
tb = TensorBoard(log_dir='logs', write_graph=True, update_freq='batch')

# Keep only the weights of the epoch with the highest validation Dice
# score. The two flags are real booleans: the strings 'True' only worked
# because any non-empty string is truthy, and 'False' would have been too.
mc = ModelCheckpoint(filepath='top_weights.h5',
                     monitor='val_dice_coef',
                     mode='max',
                     save_best_only=True,
                     save_weights_only=True, verbose=1)
def learning_rate_fn(epoch):
    """Return the learning rate for a given epoch index.

    Schedule:
        epoch < 5         -> 1e-5
        5 <= epoch < 10   -> 2e-5
        10 <= epoch <= 45 -> 1e-5
        epoch > 45        -> 5e-6

    # Arguments
        epoch: epoch index.

    # Returns
        The learning rate as a float, or `None` if `epoch` matches none of
        the comparisons (e.g. NaN).
    """
    if epoch < 5:
        return 1e-5
    if epoch < 10:
        return 2e-5
    if epoch <= 45:
        return 1e-5
    if epoch > 45:
        return 5e-6
    return None
# Drive the learning rate from `learning_rate_fn` at the start of each epoch.
lr_schedule = tf.keras.callbacks.LearningRateScheduler(learning_rate_fn)
callbacks = [mc, tb, lr_schedule]

# Both datasets repeat forever, so the number of steps per epoch has to be
# given explicitly: one pass over the file lists in full batches.
model.fit(
    train_dataset,
    steps_per_epoch=len(train_images) // batch_size,
    epochs=200,
    validation_data=val_dataset,
    validation_steps=len(val_images) // batch_size,
    callbacks=callbacks,
)

# Save the weights of the final epoch as well.
model.save_weights('last_epoch.h5')
下面是一些效果
可以看出效果不是很好,我也会做出改进。