语法
tf.keras.layers.Embedding(
input_dim, output_dim, embeddings_initializer='uniform',
embeddings_regularizer=None, activity_regularizer=None,
embeddings_constraint=None, mask_zero=False, input_length=None, **kwargs
)
该函数是神经网络的Embedding层,其仅可用于神经网络的第一层。
import numpy as np
import tensorflow as tf  # required: the example refers to `tf.keras` below

# Use Embedding as the first layer of the model: ids below 1000 are mapped to
# 64-dimensional vectors, for sequences of length 10.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Embedding(1000, 64, input_length=10))
# Random integer ids in [0, 1000) with shape (batch_size=32, input_length=10).
input_array = np.random.randint(1000, size=(32, 10))
model.compile('rmsprop', 'mse')
output_array = model.predict(input_array)
print(output_array.shape)
# (32, 10, 64)
参数
参数 | 意义 |
---|---|
input_dim | [tf.int64 ]词汇表的大小,即最大整数索引+1。 |
output_dim | [tf.int64 ]稠密嵌入的维数。 |
embeddings_initializer | 嵌入矩阵的初始化器(参见keras.initializers )。 |
embeddings_regularizer | 应用于嵌入矩阵的正则化器函数(参见keras.regularizers )。 |
embeddings_constraint | 应用于嵌入矩阵的约束函数(参见keras.constraints )。 |
mask_zero | [tf.bool ]输入值0是否为应被屏蔽的特殊“填充”值。这在使用可以接受可变长度输入的循环层(recurrent layers)时非常有用。如果为 True,则模型中所有后续层都需要支持掩码,否则将引发异常。如果 mask_zero 设置为 True,则索引0不能在词汇表中使用(input_dim应等于词汇表的大小+1)。 |
input_length | 输入序列的长度(当其为常数时)。如果要在该层之后依次连接 Flatten 层和 Dense 层,则需要此参数;如果没有此参数,则无法计算 Dense 层输出的形状。 |
输入
(batch_size, input_length)
的二维张量。
输出
(batch_size, input_length, output_dim)
的三维张量。
默认情况下,如果GPU可用,则嵌入矩阵将放置在GPU上。这可以实现最佳性能,但可能会导致以下问题:您可能正在使用不支持稀疏GPU内核的优化器。在这种情况下,您将在训练模型时看到一个错误。您的嵌入矩阵可能太大,无法全部读入GPU。在这种情况下,您将看到内存不足(OOM)错误。此时,应该将嵌入矩阵放在CPU内存上。您可以使用设备作用域执行此操作,例如:
# Create and build the layer inside a CPU device scope so that the embedding
# matrix is placed in CPU memory.
with tf.device('cpu:0'):
    embedding_layer = Embedding(...)
    embedding_layer.build()
函数实现
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-classes-have-attributes
from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.layers.Embedding')
class Embedding(Layer):
    """Layer that maps integer indices to rows of an embedding matrix.

    The layer owns a single weight, `embeddings`, of shape
    `(input_dim, output_dim)`, created in `build`. `call` looks up one row
    per input index, so the output has the input's shape with an extra
    trailing axis of size `output_dim`.

    Args:
        input_dim: Positive integer, number of rows of the embedding matrix.
        output_dim: Positive integer, number of columns of the embedding
            matrix (size of each looked-up vector).
        embeddings_initializer: Initializer identifier for the embedding
            matrix, resolved with `initializers.get`.
        embeddings_regularizer: Regularizer identifier applied to the
            embedding matrix, resolved with `regularizers.get`.
        activity_regularizer: Regularizer identifier stored as the layer's
            `activity_regularizer`, resolved with `regularizers.get`.
        embeddings_constraint: Constraint identifier applied to the embedding
            matrix, resolved with `constraints.get`.
        mask_zero: If true, `compute_mask` returns a mask that is False
            wherever the input equals 0, and the layer reports
            `supports_masking`.
        input_length: Optional length of the input sequences; an int, or a
            list/tuple of ints when the input has more than two dimensions.
            When truthy it is used as the default `input_shape`, and it is
            checked against the actual input in `compute_output_shape`.
        **kwargs: Forwarded to `Layer.__init__`. `autocast` is always
            overwritten with `False`.

    Raises:
        ValueError: If `input_dim` or `output_dim` is not positive.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 embeddings_initializer='uniform',
                 embeddings_regularizer=None,
                 activity_regularizer=None,
                 embeddings_constraint=None,
                 mask_zero=False,
                 input_length=None,
                 **kwargs):
        # Derive a default input shape unless the caller supplied one.
        if 'input_shape' not in kwargs:
            if input_length:
                kwargs['input_shape'] = (input_length,)
            else:
                kwargs['input_shape'] = (None,)
        if input_dim <= 0 or output_dim <= 0:
            raise ValueError(
                'Both `input_dim` and `output_dim` should be positive, '
                'found input_dim {} and output_dim {}'.format(
                    input_dim, output_dim))
        if (not base_layer_utils.v2_dtype_behavior_enabled() and
                'dtype' not in kwargs):
            # In TF1, the dtype defaults to the input dtype which is typically
            # int32, so explicitly set it to floatx
            kwargs['dtype'] = backend.floatx()
        # We set autocast to False, as we do not want to cast floating-point
        # inputs to self.dtype. In call(), we cast to int32, and casting to
        # self.dtype before casting to int32 might cause the int32 values to
        # be different due to a loss of precision.
        kwargs['autocast'] = False
        super(Embedding, self).__init__(**kwargs)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embeddings_initializer = initializers.get(embeddings_initializer)
        self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.embeddings_constraint = constraints.get(embeddings_constraint)
        self.mask_zero = mask_zero
        self.supports_masking = mask_zero
        self.input_length = input_length

    @tf_utils.shape_type_conversion
    def build(self, input_shape=None):
        """Create the `(input_dim, output_dim)` embedding weight.

        Args:
            input_shape: Unused; the weight shape depends only on
                `input_dim` and `output_dim`.
        """
        self.embeddings = self.add_weight(
            shape=(self.input_dim, self.output_dim),
            initializer=self.embeddings_initializer,
            name='embeddings',
            regularizer=self.embeddings_regularizer,
            constraint=self.embeddings_constraint,
            experimental_autocast=False)
        self.built = True

    def compute_mask(self, inputs, mask=None):
        """Return the mask for `inputs`, or None when masking is disabled.

        Args:
            inputs: Tensor of indices passed to the layer.
            mask: Unused.

        Returns:
            None if `mask_zero` is false; otherwise the element-wise result
            of `inputs != 0`.
        """
        if not self.mask_zero:
            return None
        return math_ops.not_equal(inputs, 0)

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Compute the output shape: the input shape plus `output_dim`.

        Args:
            input_shape: Shape of the input, batch dimension first.

        Returns:
            A tuple `(batch, *lengths, output_dim)`. When `input_length` is
            set, its non-None entries are used for `lengths`, and None
            entries are filled in from `input_shape`.

        Raises:
            ValueError: If `input_length` has a different number of
                dimensions than the non-batch part of `input_shape`, or if
                a known length disagrees with the corresponding input
                dimension.
        """
        if self.input_length is None:
            return input_shape + (self.output_dim,)

        # input_length can be tuple if input is 3D or higher
        if isinstance(self.input_length, (list, tuple)):
            in_lens = list(self.input_length)
        else:
            in_lens = [self.input_length]

        if len(in_lens) != len(input_shape) - 1:
            raise ValueError('"input_length" is %s, '
                             'but received input has shape %s' % (str(
                                 self.input_length), str(input_shape)))

        for i, (s1, s2) in enumerate(zip(in_lens, input_shape[1:])):
            if s1 is not None and s2 is not None and s1 != s2:
                raise ValueError('"input_length" is %s, '
                                 'but received input has shape %s' % (str(
                                     self.input_length), str(input_shape)))
            elif s1 is None:
                # Unknown configured length: take it from the actual input.
                in_lens[i] = s2
        return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)

    def call(self, inputs):
        """Look up the embedding vector for every index in `inputs`.

        Args:
            inputs: Tensor of indices. Anything that is not already int32 or
                int64 is cast to int32 before the lookup.

        Returns:
            The looked-up embeddings, cast to the policy's compute dtype
            when that differs from the variable dtype.
        """
        dtype = backend.dtype(inputs)
        if dtype != 'int32' and dtype != 'int64':
            inputs = math_ops.cast(inputs, 'int32')
        out = embedding_ops.embedding_lookup_v2(self.embeddings, inputs)
        policy = self._dtype_policy
        if policy.compute_dtype != policy.variable_dtype:
            # Instead of casting the variable as in most layers, cast the
            # output, as this is mathematically equivalent but is faster.
            out = math_ops.cast(out, policy.compute_dtype)
        return out

    def get_config(self):
        """Return the layer configuration as a dict.

        Returns:
            The base layer config merged with this layer's constructor
            arguments; initializer, regularizers and constraint are stored
            in serialized form. This layer's entries take precedence over
            base entries with the same key.
        """
        config = {
            'input_dim': self.input_dim,
            'output_dim': self.output_dim,
            'embeddings_initializer':
                initializers.serialize(self.embeddings_initializer),
            'embeddings_regularizer':
                regularizers.serialize(self.embeddings_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'embeddings_constraint':
                constraints.serialize(self.embeddings_constraint),
            'mask_zero': self.mask_zero,
            'input_length': self.input_length
        }
        base_config = super(Embedding, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))