深入浅出TensorFlow2函数——tf.keras.layers.Embedding

分类目录:《深入浅出TensorFlow2函数》总目录


语法
tf.keras.layers.Embedding(
    input_dim, output_dim, embeddings_initializer='uniform',
    embeddings_regularizer=None, activity_regularizer=None,
    embeddings_constraint=None, mask_zero=False, input_length=None, **kwargs
)

该函数是神经网络的Embedding层,其仅可用于神经网络的第一层。

import numpy as np
model = tf.keras.Sequential()
model.add(tf.keras.layers.Embedding(1000, 64, input_length=10))
input_array = np.random.randint(1000, size=(32, 10))
model.compile('rmsprop', 'mse')
output_array = model.predict(input_array)
print(output_array.shape)
# (32, 10, 64)
参数
参数意义
input_dim[tf.int64]词汇表的大小,即最大整数索引+1。
output_dim[tf.int64]稠密嵌入的维数。
embeddings_initializer嵌入矩阵的初始值设定项(参见keras.regularizers)。
embeddings_regularizer应用于嵌入矩阵的正则化器函数(参见keras.regularizers)。
embeddings_constraint应用于嵌入矩阵的约束函数(参见keras.constraints)。
mask_zero[tf.bool]无论输入值0是否为应屏蔽的特殊“填充”值。这在使用可能接受可变长度输入的重复层时非常有用。如果这是真的,那么模型中的所有后续层都需要支持掩蔽,否则将引发异常。如果mask_zero设置为True`,则索引0不能在词汇表中使用(input_dim应等于词汇表的大小+1)。
input_length输入序列的长度,当其为常数时。如果要将“展平Dense层”连接到上游,则需要此参数,如果没有此参数,则无法计算密集输出的形状。
输入

(batch_size, input_length)的二维张量。

输出

(batch_size, input_length, output_dim)的三维张量。

默认情况下,如果GPU可用,则嵌入矩阵将放置在GPU上。这可以实现最佳性能,但可能会导致以下问题:您可能正在使用不支持稀疏GPU内核的优化器。在这种情况下,您将在训练模型时看到一个错误。您的嵌入矩阵可能太大,无法全部读入GPU。在这种情况下,您将看到内存不足(OOM)错误。此时,应该将嵌入矩阵放在CPU内存上。您可以使用设备作用域执行此操作,例如:

with tf.device('cpu:0'):
  embedding_layer = Embedding(...)
  embedding_layer.build()
函数实现
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=g-classes-have-attributes

from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.Embedding')
class Embedding(Layer):


  def __init__(self,
               input_dim,
               output_dim,
               embeddings_initializer='uniform',
               embeddings_regularizer=None,
               activity_regularizer=None,
               embeddings_constraint=None,
               mask_zero=False,
               input_length=None,
               **kwargs):
    if 'input_shape' not in kwargs:
      if input_length:
        kwargs['input_shape'] = (input_length,)
      else:
        kwargs['input_shape'] = (None,)
    if input_dim <= 0 or output_dim <= 0:
      raise ValueError('Both `input_dim` and `output_dim` should be positive, '
                       'found input_dim {} and output_dim {}'.format(
                           input_dim, output_dim))
    if (not base_layer_utils.v2_dtype_behavior_enabled() and
        'dtype' not in kwargs):
      # In TF1, the dtype defaults to the input dtype which is typically int32,
      # so explicitly set it to floatx
      kwargs['dtype'] = backend.floatx()
    # We set autocast to False, as we do not want to cast floating- point inputs
    # to self.dtype. In call(), we cast to int32, and casting to self.dtype
    # before casting to int32 might cause the int32 values to be different due
    # to a loss of precision.
    kwargs['autocast'] = False
    super(Embedding, self).__init__(**kwargs)

    self.input_dim = input_dim
    self.output_dim = output_dim
    self.embeddings_initializer = initializers.get(embeddings_initializer)
    self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.embeddings_constraint = constraints.get(embeddings_constraint)
    self.mask_zero = mask_zero
    self.supports_masking = mask_zero
    self.input_length = input_length

  @tf_utils.shape_type_conversion
  def build(self, input_shape=None):
    self.embeddings = self.add_weight(
        shape=(self.input_dim, self.output_dim),
        initializer=self.embeddings_initializer,
        name='embeddings',
        regularizer=self.embeddings_regularizer,
        constraint=self.embeddings_constraint,
        experimental_autocast=False)
    self.built = True

  def compute_mask(self, inputs, mask=None):
    if not self.mask_zero:
      return None
    return math_ops.not_equal(inputs, 0)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if self.input_length is None:
      return input_shape + (self.output_dim,)
    else:
      # input_length can be tuple if input is 3D or higher
      if isinstance(self.input_length, (list, tuple)):
        in_lens = list(self.input_length)
      else:
        in_lens = [self.input_length]
      if len(in_lens) != len(input_shape) - 1:
        raise ValueError('"input_length" is %s, '
                         'but received input has shape %s' % (str(
                             self.input_length), str(input_shape)))
      else:
        for i, (s1, s2) in enumerate(zip(in_lens, input_shape[1:])):
          if s1 is not None and s2 is not None and s1 != s2:
            raise ValueError('"input_length" is %s, '
                             'but received input has shape %s' % (str(
                                 self.input_length), str(input_shape)))
          elif s1 is None:
            in_lens[i] = s2
      return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)

  def call(self, inputs):
    dtype = backend.dtype(inputs)
    if dtype != 'int32' and dtype != 'int64':
      inputs = math_ops.cast(inputs, 'int32')
    out = embedding_ops.embedding_lookup_v2(self.embeddings, inputs)
    if self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype:
      # Instead of casting the variable as in most layers, cast the output, as
      # this is mathematically equivalent but is faster.
      out = math_ops.cast(out, self._dtype_policy.compute_dtype)
    return out

  def get_config(self):
    config = {
        'input_dim': self.input_dim,
        'output_dim': self.output_dim,
        'embeddings_initializer':
            initializers.serialize(self.embeddings_initializer),
        'embeddings_regularizer':
            regularizers.serialize(self.embeddings_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'embeddings_constraint':
            constraints.serialize(self.embeddings_constraint),
        'mask_zero': self.mask_zero,
        'input_length': self.input_length
    }
    base_config = super(Embedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
  • 10
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

von Neumann

您的赞赏是我创作最大的动力~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值