# 添加mask操作 (AddMask: attach a fixed mask):
class AddMask(keras.layers.Layer):
    """Identity layer that attaches a fixed mask to its output.

    The inputs are passed through unchanged. The layer's only effect is that
    ``compute_mask`` reports ``mask`` (the value given to ``__init__``) as
    the output mask, replacing whatever mask arrived with the inputs.

    # Arguments
        mask: The mask to report for the output. With the default ``None``
            the layer reports no mask, i.e. it drops any incoming mask.
            NOTE(review): the type/shape of ``mask`` (e.g. a boolean tensor
            matching the input's leading dimensions) is not validated here;
            confirm against callers.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, mask=None, **kwargs):
        super(AddMask, self).__init__(**kwargs)
        # Flag this layer as mask-aware for Keras.
        self.supports_masking = True
        # Stored as-is; returned verbatim by compute_mask.
        self.mask = mask

    def compute_mask(self, inputs, mask=None):
        """Return the stored mask.

        Both ``inputs`` and the incoming ``mask`` are ignored, so the mask
        given at construction time always replaces any upstream mask.
        """
        return self.mask

    def call(self, inputs, **kwargs):
        """Return ``inputs`` unchanged; extra keyword arguments are ignored."""
        return inputs
# 去掉mask操作 (RemoveMask: drop the mask):
class RemoveMask(keras.layers.Layer):
    """Identity layer that strips any mask from its input.

    Tensors flow through untouched. ``compute_mask`` always answers
    ``None``, so the output of this layer carries no mask regardless of
    what mask (if any) came in.

    # Arguments
        **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, **kwargs):
        super(RemoveMask, self).__init__(**kwargs)
        # Flag this layer as mask-aware for Keras.
        self.supports_masking = True

    def call(self, inputs, **kwargs):
        """Hand ``inputs`` back as-is; keyword arguments are not used."""
        return inputs

    def compute_mask(self, inputs, mask=None):
        """Report no output mask, whatever ``inputs`` or ``mask`` were given."""
        return None