解析Torch中 `Embedding`

依次关系:
Embedding
( 1 )-> Parameter
( 2 )-> embedding -> has_torch_function_variadic -> handle_torch_function-> _get_overloaded_args

首先详细解析(1)和(2),最后基于上述知识解析Embedding

一. 开始解析(1)和(2):

( 1 )-> Parameter

class Parameter(torch.Tensor):
    r"""一种被视为模块参数的Tensor类型。

    参数是 :class:~torch.Tensor 的子类,当与 :class:Module 一起使用时,具有非常特殊的一个特性——当它们作为 :class:Module 的属性被赋值时,
    会自动被添加到该模块的参数列表中,并且会出现在如 :meth:~Module.parameters 的迭代器中。直接赋值一个普通的Tensor并不具备这样的效果。
    这是因为有时我们可能希望在模型中缓存一些临时状态,比如RNN的上一次隐藏状态。如果没有 :class:Parameter 这样的类,
    这些临时变量也会被注册为模型的参数。

    参数:
        data (Tensor): 参数的Tensor数据。
        requires_grad (布尔值, 可选): 表示该参数是否需要计算梯度。更多细节请参考 :ref:locally-disable-grad-doc。默认值: True
    """
    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.tensor([])
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
            memo[id(self)] = result
            return result

    def __repr__(self):
        return 'Parameter containing:\n' + super(Parameter, self).__repr__()

    def __reduce_ex__(self, proto):
        # See Note [Don't serialize hooks]
        return (
            torch._utils._rebuild_parameter,
            (self.data, self.requires_grad, OrderedDict())
        )

    __torch_function__ = _disabled_torch_function_impl

(a)__new__

def __new__(cls, data=None, requires_grad=True):
    if data is None:
        data = torch.tensor([])
    return torch.Tensor._make_subclass(cls, data, requires_grad)

__new__方法在Python中是一个特殊的方法,用于控制一个类的实例化过程。在PyTorch中,Parameter类的__new__方法被重写以定制Parameter实例的创建过程。让我们深入分析这段代码:

def __new__(cls, data=None, requires_grad=True):

这里定义了Parameter类的__new__方法,它接受两个参数:

data:这是用于初始化Parameter对象的数据,可以是任何可以转换为Tensor的对象,默认为None。
requires_grad:一个布尔值,表示Parameter对象是否需要计算梯度,默认为True。

if data is None: data = torch.tensor([])

首先,检查data参数是否为None。如果是None,则创建一个空的Tensor。这确保即使用户没有提供初始数据,Parameter也可以被正确初始化。

return torch.Tensor._make_subclass(cls, data, requires_grad)

这是__new__方法的核心部分。torch.Tensor._make_subclass是一个内部方法,用于从现有的Tensor创建一个子类实例。在这个方法调用中:

cls:指的是Parameter类本身,这告诉_make_subclass方法创建一个Parameter类的实例。
data:这是用于初始化Parameter的数据,它应该是已经转换为Tensor的对象。
requires_grad:一个布尔值,表示这个Parameter实例是否需要计算梯度。
通过调用torch.Tensor._make_subclass,Parameter类的__new__方法实际上是在创建一个Tensor的子类实例,这个子类继承自Parameter类,并且具有datarequires_grad属性。

这个过程确保了Parameter实例不仅具有普通Tensor的所有属性和方法,而且还具有Parameter类的特有功能,比如自动被加入到所属Module的参数列表中,从而可以在训练过程中被优化算法正确地识别和更新。

总结来说,Parameter类的__new__方法通过torch.Tensor._make_subclass创建了一个既有Tensor特性又有Parameter特性的对象,这在PyTorch的模型构建和训练过程中起到了关键作用。
(b) __deepcopy__

def __deepcopy__(self, memo):
    if id(self) in memo:
        return memo[id(self)]
    else:
        result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
        memo[id(self)] = result
        return result

__deepcopy__方法是Python的内置方法,用于支持深度拷贝(deep copy)操作。深度拷贝意味着创建一个对象的完全独立副本,包括所有嵌套的对象。在PyTorch中,Parameter类的__deepcopy__方法实现了对Parameter对象的深度拷贝,确保拷贝的Parameter对象与其原对象在内存中是完全分离的。

让我们逐步解析__deepcopy__方法的实现:

if id(self) in memo:

这里的id(self)获取的是当前Parameter对象的唯一内存地址。memo是一个字典,用于存储已经拷贝过的对象,以避免重复拷贝相同的对象。这一行代码检查当前Parameter对象是否已经被拷贝过了。

return memo[id(self)]

如果当前Parameter对象已经在memo字典中,说明它之前已经被拷贝过,那么就直接返回这个已经拷贝好的对象,避免重复拷贝。

else:

如果当前Parameter对象还没有被拷贝过,那么就进入这个分支。

result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)

这一行代码是创建一个新的Parameter对象的关键。它做了两件事:

使用type(self)获取当前Parameter对象的类型,也就是Parameter类本身,这样可以创建一个新的同类型的Parameter对象。
调用self.data.clone(memory_format=torch.preserve_format)来创建data属性的深拷贝。memory_format=torch.preserve_format参数确保拷贝后的Tensor的内存布局与原Tensor相同。这是深度拷贝的核心,确保了data属性的独立性。
设置新Parameter对象的requires_grad属性,确保拷贝的Parameter对象与原对象的梯度计算需求一致。

memo[id(self)] = result

将新创建的Parameter对象存储到memo字典中,使用当前Parameter对象的内存地址作为键,这样下次如果遇到相同的对象就可以直接从memo中取出拷贝结果。

return result

最后,返回新创建的Parameter对象。

总的来说,__deepcopy__方法确保了在深度拷贝操作中,Parameter对象及其内部的data属性都被正确地独立复制,同时通过memo字典避免了不必要的重复拷贝,提高了效率。
( c ) __repr__

def __repr__(self):
    return 'Parameter containing:\n' + super(Parameter, self).__repr__()

( d ) __reduce_ex__
这是Parameter类的字符串表示方法,用于返回Parameter的可读性描述。它返回一个字符串,其中包含了Parameter的描述信息,以及内部Tensor的字符串表示。

def __reduce_ex__(self, proto):
    # See Note [Don't serialize hooks]
    return (
        torch._utils._rebuild_parameter,
        (self.data, self.requires_grad, OrderedDict())
    )

__reduce_ex__是Python对象序列化接口的一个重要组成部分,它主要用于支持Python的pickle模块进行对象的序列化和反序列化。pickle模块是Python内置的用于序列化和反序列化复杂对象的工具,它允许将Python对象保存到磁盘文件中或者在网络上传输,之后可以恢复为原来的对象状态。__reduce_ex__方法的实现对于支持pickle模块的序列化至关重要,尤其是在处理像Parameter这样的复杂对象时。

解析__reduce_ex__方法

def __reduce_ex__(self, proto):

proto参数是序列化协议的版本号,__reduce_ex__方法需要根据这个版本号来确定序列化的策略。但在大多数情况下,proto参数并不会直接影响序列化的过程。

See Note [Don’t serialize hooks]
这个注释指向了PyTorch源代码中的一个注释,标题为[Don’t serialize hooks]。这表明在序列化Parameter对象时,不应该序列化与之关联的任何钩子(hooks)。钩子在PyTorch中用于在前向或后向传播过程中执行额外的操作,但是它们通常是动态的,不应该被硬编码到序列化后的对象中。因此,这里暗示了在序列化Parameter时需要小心处理,以避免序列化钩子。

return (torch._utils._rebuild_parameter, (self.data, self.requires_grad, OrderedDict()))

这行代码是__reduce_ex__方法的主体,它返回一个元组,用于指导pickle模块如何序列化和反序列化Parameter对象:

torch._utils._rebuild_parameter:这是一个内部函数,用于在反序列化时重建Parameter对象。当pickle模块在反序列化过程中遇到这个元组时,它会调用_rebuild_parameter函数,并传入后面元组中的参数。
(self.data, self.requires_grad, OrderedDict()):这是用于重建Parameter对象所需的数据。它包括:
self.data:Parameter对象的data属性,这是存储实际数值的Tensor。
self.requires_grad:一个布尔值,表示Parameter对象是否需要计算梯度。
OrderedDict():一个空的有序字典。在PyTorch的序列化过程中,这里原本可能会包含其他信息,如钩子等,但由于前面的注释提示,这里传递的是一个空字典,以确保不会序列化钩子。
总结
__reduce_ex__方法通过指定一个重建函数和必要的参数,确保了Parameter对象在序列化和反序列化过程中的正确性。它避免了序列化钩子,同时保留了Parameter对象的核心属性,如data和requires_grad,从而在保存和恢复模型时能够保持一致性。
( e ) __torch_function__

__torch_function__ = _disabled_torch_function_impl

这是Parameter类的一个属性,用于禁用Parameter的__torch_function__机制。在PyTorch中,__torch_function__是用于操作子类化的Tensor的特殊机制,但是在这里被禁用,以确保Parameter在与其他Tensor进行运算时的行为符合预期。

综上所述,这些方法共同确保了Parameter类的正确创建、深拷贝、字符串表示和序列化行为,同时也确保了在与其他Tensor进行运算时,Parameter的行为是可控和一致的。

( 2 )-> embedding -> has_torch_function_variadic -> handle_torch_function-> _get_overloaded_args

def embedding(
    input: Tensor,
    weight: Tensor,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    r"""一个简单的查找表,用于在固定词典和固定大小中查找嵌入向量。

    该模块常用来依据索引检索词嵌入。模块的输入是一系列索引,以及嵌入矩阵,输出是与这些索引相对应的词嵌入向量。

    更多详情请参见 :class:torch.nn.Embedding 类。

    参数:
        input (LongTensor): 包含指向嵌入矩阵索引的张量
        weight (Tensor): 嵌入矩阵,其行数等于最大可能索引+1,列数等于嵌入向量的维度
        padding_idx (int, 可选): 如果指定了该参数,位置 :attr:padding_idx 的条目不会对梯度有贡献;
        因此,在 :attr:padding_idx 处的嵌入向量在训练过程中不会被更新,
        即保持为一个固定的“填充”不变。
        max_norm (float, 可选): 如果给定,所有大于 :attr:max_norm 范数的嵌入向量将被重归一化,使其范数等于 :attr:max_norm。
        注意:这将就地修改 :attr:weight。
        norm_type (float, 可选): 对于 :attr:max_norm 选项,计算p-范数的p值。默认值为 2。
        scale_grad_by_freq (布尔值, 可选): 如果给定,这将按小批量中单词频率的倒数缩放梯度。默认值为 False。
        sparse (布尔值, 可选): 如果为 True,相对于 :attr:weight 的梯度将是一个稀疏张量。更多关于稀疏梯度的细节请参见
        :class:torch.nn.Embedding 类的注意事项。

    形状:
        - 输入: LongTensor,任意形状,包含要提取的索引
        - 权重: 浮点类型的嵌入矩阵,形状为 (V, embedding_dim),
        其中 V = 最大索引 + 1,embedding_dim = 嵌入向量的维度
        - 输出: (*, embedding_dim),其中 * 是输入的形状

    Examples::

        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.tensor([[1,2,4,5],[4,3,2,9]])
        >>> # an embedding matrix containing 10 tensors of size 3
        >>> embedding_matrix = torch.rand(10, 3)
        >>> F.embedding(input, embedding_matrix)
        tensor([[[ 0.8490,  0.9625,  0.6753],
                 [ 0.9666,  0.7761,  0.6108],
                 [ 0.6246,  0.9751,  0.3618],
                 [ 0.4161,  0.2419,  0.7383]],

                [[ 0.6246,  0.9751,  0.3618],
                 [ 0.0237,  0.7794,  0.0528],
                 [ 0.9666,  0.7761,  0.6108],
                 [ 0.3385,  0.8612,  0.1867]]])

        >>> # example with padding_idx
        >>> weights = torch.rand(10, 3)
        >>> weights[0, :].zero_()
        >>> embedding_matrix = weights
        >>> input = torch.tensor([[0,2,0,5]])
        >>> F.embedding(input, embedding_matrix, padding_idx=0)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.5609,  0.5384,  0.8720],
                 [ 0.0000,  0.0000,  0.0000],
                 [ 0.6262,  0.2438,  0.7471]]])
    """

    if has_torch_function_variadic(input, weight):
        return handle_torch_function(
            embedding, (input, weight),
            input, weight, padding_idx, max_norm, norm_type,
            scale_grad_by_freq, sparse
        )
    if padding_idx is not None:
        if padding_idx > 0:
            assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings"
        elif padding_idx < 0:
            assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings"
            padding_idx = weight.size(0) + padding_idx
    else:
        padding_idx = -1
    if max_norm is not None:
        # Note [embedding_renorm contiguous]
        # `embedding_renorm_` will call .contiguous() on input anyways, so we
        # call it here and take advantage of the improved locality in the
        # `embedding` call below too.
        input = input.contiguous()
        # Note [embedding_renorm set_grad_enabled]
        # XXX: equivalent to
        # with torch.no_grad():
        #   torch.embedding_renorm_
        # remove once script supports set_grad_enabled
        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

( a ) has_torch_function_variadic

has_torch_function_variadic = _add_docstr(
    _has_torch_function_variadic,
	r"""这是has_torch_function的一种特殊情况,它跳过了元组的创建。
	
	此方法利用了Python 3.7中引入的METH_FASTCALL协议;对于3.6及更早版本,与has_torch_function相比,它的性能大致相当。
	
	而不是这样调用:
	has_torch_function((a, b))
	而是调用:
	has_torch_function_variadic(a, b)
	这样可以避免不必要的打包和解包工作。
	"""
)

has_torch_function_variadichas_torch_function方法的一个特定优化版本,它专门设计用于避免在处理多个参数时创建元组的开销。在 PyTorch 中,has_torch_function 用于检测传入的参数中是否有对象定义了 __torch_function__,这是用于自定义张量操作的特殊方法。然而,当涉及到多个参数时,常规的 has_torch_function 方法需要将所有参数打包进一个元组,然后再在函数内部解包,这一步骤在某些情况下可能会引入不必要的性能损耗。

has_torch_function_variadic 的引入就是为了克服这一限制。它利用了 Python 3.7 引入的 METH_FASTCALL 协议,该协议优化了对变长参数列表的处理,允许函数直接接收任意数量的位置参数,而无需显式创建元组。因此,相比于 has_torch_functionhas_torch_function_variadic 在处理多参数时可以避免创建和解包元组的开销,从而提高性能。

对于 Python 3.6 及更早版本,has_torch_function_variadic 的性能与 has_torch_function 相当,因为它不依赖于 METH_FASTCALL 协议。然而,即使在这些旧版本的 Python 中,使用 has_torch_function_variadic 依然可以避免不必要的元组操作,从而可能带来轻微的性能提升。

总之,has_torch_function_variadic 是为了提高 PyTorch 在处理多个参数时的性能而设计的,特别是在 Python 3.7 及以上版本中,它通过直接处理变长参数列表,避免了创建和解包元组的额外开销,从而实现了更高效的代码执行。在编写涉及大量张量操作的高性能代码时,选择使用 has_torch_function_variadic 能够帮助开发者进一步优化程序的运行速度。

( b ) handle_torch_function

def handle_torch_function(
        public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
    """Implement a function with checks for ``__torch_function__`` overrides.

    主要作用是在 PyTorch 中处理和调用由用户自定义的 __torch_function__ 方法。
    在 PyTorch 中,__torch_function__ 允许用户自定义数据类型能够与 torch.Tensor 类型交互,使得自定义类型的对象能够在 PyTorch 的函数中作为参数传递并得到适当处理

    See torch::autograd::handle_torch_function for the equivalent of this
    function in the C++ implementation.

    Arguments
    ---------
    public_api : function
        Function exposed by the public torch API originally called like
        ``public_api(*args, **kwargs)`` on which arguments are now being
        checked.
         被调用的公共 API 函数,即原本用户想要调用的函数,如 torch.add 或者任何其他 PyTorch 提供的函数
    relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.
        需要检查是否具有 __torch_function__ 方法的参数集合
    args : tuple
        Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : tuple
        Arbitrary keyword arguments originally passed into ``public_api``.
        原始传给 public_api 的位置参数和关键字参数

    Returns
    -------
    object
        Result from calling ``implementation`` or an ``__torch_function__``
        method, as appropriate.

    Raises
    ------
    TypeError : if no implementation is found.

    Example
    -------
    >>> def func(a):
    ...     if type(a) is not torch.Tensor:  # This will make func dispatchable by __torch_function__
    ...         return handle_torch_function(func, (a,), a)
    ...     return a + 0
    """
    # Check for __torch_function__ methods.
    # 检查 __torch_function__ 方法:
    # 使用 _get_overloaded_args 函数找到 relevant_args 中具有 __torch_function__ 方法的参数。这一步是为了确定哪些参数可能需要被特殊处理
    overloaded_args = _get_overloaded_args(relevant_args)
    # overloaded_args already have unique types.
    # 获取重载参数的类型:
    # 对于每一个找到的具有 __torch_function__ 方法的参数,获取其类型。这些类型信息对于后续的 __torch_function__ 方法调用是必要的
    types = tuple(map(type, overloaded_args))

    # Call overrides
    # 调用 __torch_function__ 方法:
    # 遍历所有具有 __torch_function__ 方法的参数,并尝试调用它们的 __torch_function__ 方法。方法调用时传入 public_api、参数类型集合、原始参数和关键字参数。
    # 如果任意一个 __torch_function__ 方法返回的结果不是 NotImplemented,那么直接返回这个结果。这意味着自定义类型已经成功处理了这个函数调用。
    for overloaded_arg in overloaded_args:
        # Use `public_api` instead of `implementation` so __torch_function__
        # implementations can do equality/identity comparisons.
        result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)

        if result is not NotImplemented:
            return result
    # 错误处理:
    # 如果所有的 __torch_function__ 方法都返回 NotImplemented,那么函数会抛出一个 TypeError,指出没有找到适合当前参数类型组合的实现
    func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
    raise TypeError("no implementation found for '{}' on types that implement "
                    '__torch_function__: {}'
                    .format(func_name, [type(arg) for arg in overloaded_args]))

handle_torch_function函数是PyTorch中用于处理涉及__torch_function__方法的函数调用的核心机制。这个函数的主要目的是检测和执行与__torch_function__相关的自定义操作,使得用户可以对特定类型的数据结构定义自定义的数学或张量操作,而不仅仅是传统的Tensor对象。下面是对这个函数的详细解释:

函数定义和参数

def handle_torch_function(
        public_api: Callable, 
        relevant_args: Iterable[Any], 
        *args, 
        **kwargs) -> Any:

public_api: 这是原始调用的公共API函数,通常是一个PyTorch提供的函数或方法,例如torch.add。这个参数是为了让__torch_function__的实现者能够访问原始的函数,以便进行比较或调用。
relevant_args: 这是一个包含需要检查__torch_function__方法的参数的可迭代对象。这些参数是public_api函数的输入参数,用于确定是否存在自定义操作。
*args **kwargs: 这些是传递给public_api函数的原始位置参数和关键字参数。
函数体
检查__torch_function__方法:首先,函数调用_get_overloaded_args来收集所有具有__torch_function__属性的参数。这些参数被称为overloaded_args,它们将用于后续的自定义操作检测。
执行自定义操作:对于每个overloaded_args中的参数,函数尝试调用其__torch_function__方法,传入public_api函数、参数类型、原始位置参数args和关键字参数kwargs。如果__torch_function__方法返回的结果不是NotImplemented,则返回这个结果,这意味着自定义操作已经被成功执行。
异常处理:如果没有任何__torch_function__方法返回有效的结果,函数将抛出一个TypeError异常,指出没有找到针对当前类型组合的实现。
示例
在函数的注释中,有一个示例展示了如何在自定义函数中使用handle_torch_function。如果函数func的输入a不是一个Tensor,那么func将调用handle_torch_function来检查是否存在针对a__torch_function__方法。如果存在,那么func将执行这个自定义操作;如果不存在,func将执行默认操作(在这个例子中是a + 0)。

总结
handle_torch_function是PyTorch中实现高度动态和可扩展的张量操作的关键组件。它允许用户定义自己的数据类型,并为这些类型定义自定义的数学或张量操作,极大地增强了框架的灵活性和适用范围。通过检测和执行__torch_function__方法,handle_torch_function确保了用户自定义操作的正确性和效率,同时也维护了PyTorch核心API的一致性和稳定性。

( c ) _get_overloaded_args

def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
    """返回一个参数列表,用于调用__torch_function__方法。

    本函数会检查relevant_args中适用于__torch_function__实现的参数,按照调用优先级顺序,将这些参数及其类型存储在overloaded_args和overloaded_types中。
    只考虑不重复的类型。如果一个类型是另一个类型的子类,它将具有更高的优先级,否则优先级顺序与relevant_args中参数的顺序相同,即在参数列表中从左至右的顺序。

    在本函数中实现的用于确定优先级的算法在NEP-0018_中有描述。

    在C++实现中,查看torch::append_overloaded_arg函数以获得等效功能。

    参数
    relevant_args : 可迭代的数组似对象
    需要检查__torch_function__方法的数组似对象参数的可迭代集合。

    返回
    overloaded_args : 列表
    从relevant_args中选择的参数,用于调用__torch_function__方法,按照应当调用的顺序排列。

    .. _NEP-0018:
    https://numpy.org/neps/nep-0018-array-function-protocol.html
    """
    # Runtime is O(num_arguments * num_unique_types)
    overloaded_types: Set[Type] = set()
    overloaded_args: List[Any] = []
    for arg in relevant_args:
        arg_type = type(arg)
        # We only collect arguments if they have a unique type, which ensures
        # reasonable performance even with a long list of possibly overloaded
        # arguments.
        if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__')):
            # Create lists explicitly for the first type (usually the only one
            # done) to avoid setting up the iterator for overloaded_args.
            if overloaded_types:
                overloaded_types.add(arg_type)
                # By default, insert argument at the end, but if it is
                # subclass of another argument, insert it before that argument.
                # This ensures "subclasses before superclasses".
                index = len(overloaded_args)
                for i, old_arg in enumerate(overloaded_args):
                    if issubclass(arg_type, type(old_arg)):
                        index = i
                        break
                overloaded_args.insert(index, arg)
            else:
                overloaded_types = {arg_type}
                overloaded_args = [arg]
    return overloaded_args

_get_overloaded_args函数在PyTorch中扮演着关键角色,用于从一系列参数中筛选出那些定义了__torch_function__方法的参数,并根据特定的优先级规则对它们进行排序。这个函数对于实现动态类型操作和确保正确的操作顺序至关重要。下面是对这个函数的详细解释:

函数定义和参数

def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:

relevant_args: 这是一个可迭代的对象,包含了需要检查__torch_function__方法的所有参数。这些参数可以是任何类型,包括但不限于Tensor和其他自定义数据类型。
函数逻辑
初始化:函数开始时,初始化两个变量:
overloaded_types: 一个set,用于存储已发现的、具有__torch_function__属性的不重复类型。
overloaded_args: 一个list,用于存储筛选出的、具有__torch_function__属性的参数。
参数检查和收集:接下来,函数遍历relevant_args中的每一个参数arg。对于每一个参数,函数首先获取其类型arg_type
类型和属性检查:如果arg_type尚未存在于overloaded_types中,并且它具有__torch_function__属性,那么它将被添加到overloaded_types中,并根据特定的规则插入到overloaded_args列表中。
插入规则:如果overloaded_types已经包含至少一个类型,那么arg_type将根据以下规则插入到overloaded_args中:
默认情况下,参数将被添加到列表的末尾。
然而,如果arg_type是列表中现有参数类型的一个子类,那么它将被插入到对应的父类参数之前。这种插入规则确保了“子类优先于超类”的原则,这是NEP-0018规范所推荐的。
初次收集:如果overloaded_types是空的,这意味着这是第一次发现一个具有__torch_function__属性的参数类型,因此arg_typearg将直接被添加到overloaded_typesoverloaded_args中。
返回结果:函数最后返回overloaded_args列表,该列表包含了所有具有__torch_function__属性的参数,并根据上述规则进行了排序。
总结
_get_overloaded_args函数通过检查参数类型__torch_function__属性,有效地筛选和排序了可以进行动态类型操作的参数。它遵循了NEP-0018规范中的建议,确保了“子类优先于超类”的原则,这对于正确处理继承关系中的类型操作至关重要。通过这个函数,PyTorch能够支持更加灵活和强大的自定义数据类型和操作,同时保持良好的性能和操作顺序。

回头再看embedding源码:

if has_torch_function_variadic(input, weight):
    return handle_torch_function(
        embedding, (input, weight),
        input, weight, padding_idx, max_norm, norm_type,
        scale_grad_by_freq, sparse
    )

主要检查了两个参数input和weight是否支持__torch_function__协议,并根据检查结果决定调用自定义的函数实现还是默认的embedding函数。这里是详细的步骤和解释:

检查__torch_function__支持:has_torch_function_variadic函数用于检测传入的参数inputweight是否定义了__torch_function__方法。这个方法是PyTorch中用于实现操作重载和自定义操作的关键机制,允许用户定义的数据类型响应特定的函数调用。
条件判断:如果inputweight中任何一个支持__torch_function__,则条件成立,代码将执行return语句中的内容。这意味着将尝试调用一个由用户自定义的embedding函数实现,而不是默认的embedding函数。
调用handle_torch_function:如果条件成立,handle_torch_function函数将被调用。这个函数的作用是处理__torch_function__的调用,它会尝试查找并调用inputweight中定义的自定义embedding实现。handle_torch_function函数接受多个参数,包括:
embedding:这是原始的embedding函数,如果找不到自定义实现,将调用这个函数。
(input, weight):这是包含所有相关参数的元组,用于handle_torch_function检查__torch_function__的存在。
input、weight、padding_idx、max_norm、norm_type、scale_grad_by_freq和sparse:这些都是embedding函数的参数,会被传递给自定义的embedding实现或默认的embedding函数。
返回结果:handle_torch_function函数会返回自定义embedding实现的结果,如果存在的话;否则,它将调用默认的embedding函数并返回其结果。
总之,这段代码体现了PyTorch中__torch_function__协议的强大功能,它允许用户在不影响现有API的情况下,轻松地扩展和定制函数行为。通过简单的条件检查handle_torch_function的调用,用户可以为特定的数据类型定义自定义的embedding实现,同时保持代码的简洁性和易读性。

继续看embedding源码:

    if padding_idx is not None:
        if padding_idx > 0:
            assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings"
        elif padding_idx < 0:
            assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings"
            padding_idx = weight.size(0) + padding_idx
    else:
        padding_idx = -1

主要处理了padding_idx参数,它是PyTorch中nn.Embedding类的一个重要属性,用于标识嵌入层中用于填充的特定索引。以下是代码的逐行解析和解释:

if padding_idx is not None: 这行代码检查padding_idx是否被设置。如果padding_idx没有被设置(即None),则默认不会有任何填充操作。但如果它被设置了,代码将继续执行后面的逻辑。
if padding_idx > 0: 如果padding_idx是一个正数,代码将检查这个索引是否在权重矩阵weight的有效范围内。由于weight.size(0)返回的是权重矩阵的第一维度大小(通常是词汇表的大小),所以padding_idx必须小于weight.size(0),以确保它是一个有效的索引。
assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings" 这行代码通过断言(assert语句)来确保padding_idx在有效范围内。如果padding_idx超出范围,将会抛出一个异常,提示“Padding_idx must be within num_embeddings”。
elif padding_idx < 0: 如果padding_idx是一个负数,这意味着它是指向权重矩阵末尾的一个相对索引。例如,-1表示最后一个元素,-2表示倒数第二个元素,以此类推。
assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings" 类似于正数的情况,这里也使用assert语句来确保padding_idx在有效范围内。对于负数索引,它必须大于等于-weight.size(0),以确保它指向的是权重矩阵中的一个合法元素。
padding_idx = weight.size(0) + padding_idx 如果padding_idx是负数并且在有效范围内,这行代码将把padding_idx转换成一个正数索引,使其指向权重矩阵中的正确位置。
else: 如果padding_idx既不是正数也不是负数,这意味着它最初被设置为None。在这种情况下,padding_idx将被设置为-1,这通常表示没有填充操作。
padding_idx = -1 这是else分支的一部分,将padding_idx设置为-1,指示没有填充操作。
总之,这段代码确保了padding_idx在创建nn.Embedding实例时是一个合法的索引,无论是正数、负数还是未设置(None)。它还处理了负数索引的情况,将其转换为相应的正数索引,以便在后续的嵌入操作中正确使用。

继续看embedding源码最后一部分:

    if max_norm is not None:
        # Note [embedding_renorm contiguous]
        # `embedding_renorm_` will call .contiguous() on input anyways, so we
        # call it here and take advantage of the improved locality in the
        # `embedding` call below too.
        input = input.contiguous()
        # Note [embedding_renorm set_grad_enabled]
        # XXX: equivalent to
        # with torch.no_grad():
        #   torch.embedding_renorm_
        # remove once script supports set_grad_enabled
        _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

主要处理了嵌入层(embedding)中的权重规范化(weight normalization)以及最终的嵌入层调用。下面是代码的逐行解析和整体解释:

if max_norm is not None: 这段代码首先检查max_norm参数是否被设置。max_norm用于限制嵌入向量的范数,防止它们变得过大,这有助于训练稳定性和防止梯度爆炸。
input = input.contiguous() 如果max_norm被设置,代码将确保input张量是连续存储的。这是因为后续的embedding_renorm_操作可能需要连续的内存布局以提高效率。通过提前调用.contiguous(),可以确保数据的局部性,从而可能提高后续操作的性能。
_no_grad_embedding_renorm_(weight, input, max_norm, norm_type) 这里调用了_no_grad_embedding_renorm_函数,这是一个内核级别的操作,用于在不追踪梯度的情况下重新规范化权重。这通常在训练过程中进行,以保持权重的范数不超过max_normnorm_type参数指定了范数的类型,如L1、L2等。Note [embedding_renorm set_grad_enabled]这个注释解释了为什么这里使用_no_grad_embedding_renorm_而不是普通的embedding_renorm_。原因是脚本模式(script mode)目前不支持set_grad_enabled上下文管理器。在非脚本模式下,通常会使用with torch.no_grad():来阻止梯度追踪,但脚本模式下需要使用_no_grad_embedding_renorm_这样的特定函数来达到相同的效果。
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) 最终,代码调用torch.embedding函数来执行实际的嵌入操作。weight是嵌入矩阵,input是索引张量,padding_idx用于指定填充索引,scale_grad_by_freq用于调整梯度,sparse则是一个布尔标志,表示是否使用稀疏梯度更新。
总之,这段代码实现了嵌入层的权重规范化处理,并最终调用torch.embedding来完成嵌入操作。规范化步骤确保了嵌入向量的范数不会超过设定的上限,这对于训练稳定性非常重要。此外,通过使用_no_grad_embedding_renorm_,代码在不增加额外的梯度追踪负担的情况下完成了规范化操作,这对于提高训练效率是有益的。

def _no_grad_embedding_renorm_(weight: Tensor, input: Tensor, max_norm: float, norm_type: float) -> Tensor:
    with torch.no_grad():
        torch.embedding_renorm_(weight, input, max_norm, norm_type)
def embedding_renorm_(input: Tensor, indices: Tensor, max_norm: _float, norm_type: _float) -> Tensor: ...

二. 最后解析Embedding

class Embedding(Module):
    r"""一个简单的查找表,用于存储固定词典和大小的嵌入向量。

    该模块常用来存储词嵌入,并通过索引检索它们。模块的输入是一个索引列表,而输出则是相应
    的词嵌入向量。

    参数:
        num_embeddings (int): 嵌入词典的大小
        embedding_dim (int): 每个嵌入向量的维度
        padding_idx (int, 可选): 如果指定了该参数,位于 :attr:`padding_idx` 的条目不会对梯度有贡献;
                                 因此,在训练过程中 :attr:`padding_idx` 处的嵌入向量不会被更新,
                                 即保持为一个固定的“填充”值。对于新构建的 `Embedding`,
                                 在 :attr:`padding_idx` 处的嵌入向量默认为全零,
                                 但可以更新为另一个值以作为填充向量使用。
        max_norm (float, 可选): 如果给定,所有大于 :attr:`max_norm` 范数的嵌入向量
                                将被重归一化,使其范数等于 :attr:`max_norm`。
        norm_type (float, 可选): 用于 :attr:`max_norm` 选项的 p-范数计算中的 p 值。默认值为 ``2``。
        scale_grad_by_freq (布尔值, 可选): 如果给定,这将按小批量中单词出现频率的倒数来缩放梯度。
                                           默认值为 ``False``。
        sparse (布尔值, 可选): 如果为 ``True``,相对于 :attr:`weight` 矩阵的梯度将是一个稀疏张量。
                               更多关于稀疏梯度的细节请参见注意事项。

    属性:
        weight (Tensor): 形状为 (num_embeddings, embedding_dim) 的可学习权重,
                         初始化自 :math:`\mathcal{N}(0, 1)`

    形状:
        - 输入: :math:`(*)`,任意形状的 IntTensor 或 LongTensor,包含要提取的索引
        - 输出: :math:`(*, H)`,其中 `*` 是输入的形状,:math:`H=\text{embedding\_dim}`

    .. 注意::
        请注意,目前只有有限数量的优化器支持稀疏梯度:
        目前包括 :class:`optim.SGD` (`CUDA` 和 `CPU`),
        :class:`optim.SparseAdam` (`CUDA` 和 `CPU`) 和 :class:`optim.Adagrad` (`CPU`)。

    .. 注意::
        当 :attr:`max_norm` 不为 ``None`` 时,`Embedding` 的前向方法将就地修改
        :attr:`weight` 张量。由于用于梯度计算的张量不能就地修改,在调用
        `Embedding` 的前向方法之前对 ``Embedding.weight`` 执行可微操作时,
        必须在 :attr:`max_norm` 不为 ``None`` 的情况下克隆 ``Embedding.weight``。
        例如::

            n, d, m = 3, 5, 7
            embedding = nn.Embedding(n, d, max_norm=True)
            W = torch.randn((m, d), requires_grad=True)
            idx = torch.tensor([1, 2])
            a = embedding.weight.clone() @ W.t()  # 必须克隆权重才能使此操作可微分
            b = embedding(idx) @ W.t()  # 就地修改权重
            out = (a.unsqueeze(0) + b.unsqueeze(1))
            loss = out.sigmoid().prod()
            loss.backward()

            代码解析:
                1.初始化变量:
                    n, d, m = 3, 5, 7
                    这里n代表嵌入词典的大小(即num_embeddings),d代表嵌入向量的维度(即embedding_dim),m代表另一张量W的行数。
                2.创建Embedding实例:
                    embedding = nn.Embedding(n, d, max_norm=True)
                    这行代码创建了一个Embedding对象,它将存储一个3×5的嵌入矩阵,其中max_norm=True表示所有嵌入向量将被限制在某个最大范数内。
                3.生成随机权重矩阵W:
                    W = torch.randn((m, d), requires_grad=True)
                    W是一个7×5的随机矩阵,其元素需要梯度计算。
                4.选择索引:
                    idx = torch.tensor([1, 2])
                    idx是一个包含两个整数的张量,表示我们想要从嵌入矩阵中检索出第1和第2行的向量。
                5.克隆并乘以W:
                    a = embedding.weight.clone() @ W.t()
                    这里,embedding.weight是一个3×5的张量,我们首先克隆它以避免就地修改。然后,我们用克隆后的权重矩阵乘以W的转置(W.t())。因为max_norm设置为True,直接使用embedding.weight可能导致错误,因为它在Embedding的前向传播中会被就地修改。
                6.获取嵌入向量并乘以W:
                    b = embedding(idx) @ W.t()
                    这行代码使用idx从Embedding中提取嵌入向量,这些向量随后也被乘以W的转置。这里embedding(idx)会就地修改embedding.weight,这是因为max_norm非None。
                7.计算输出并求损失:
                    out = (a.unsqueeze(0) + b.unsqueeze(1))
                    loss = out.sigmoid().prod()
                    out是a和b的外积,loss是out经过sigmoid函数后各元素的乘积。这通常是在训练模型时计算损失的一种方式。
                8.反向传播:
                    loss.backward()
                    这一步启动反向传播过程,计算损失函数相对于所有需要梯度的张量的梯度。
                关键点:
                    当max_norm非None时,embedding.weight在Embedding的前向传播中会被就地修改,这意味着在任何涉及embedding.weight的可微操作前,你都必须克隆它,以避免破坏梯度计算。
                    通过选择性地应用clone()和理解max_norm的影响,你可以正确地在你的模型中使用nn.Embedding模块,同时保证所有必要的操作都是可微的,从而可以进行有效的训练。

    Examples::

        >>> # an Embedding module containing 10 tensors of size 3
        >>> embedding = nn.Embedding(10, 3)
        >>> # a batch of 2 samples of 4 indices each
        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
        >>> embedding(input)
        tensor([[[-0.0251, -1.6902,  0.7172],
                 [-0.6431,  0.0748,  0.6969],
                 [ 1.4970,  1.3448, -0.9685],
                 [-0.3677, -2.7265, -0.1685]],

                [[ 1.4970,  1.3448, -0.9685],
                 [ 0.4362, -0.4004,  0.9400],
                 [-0.6431,  0.0748,  0.6969],
                 [ 0.9124, -2.3616,  1.1151]]])


        >>> # example with padding_idx
        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
        >>> input = torch.LongTensor([[0,2,0,5]])
        >>> embedding(input)
        tensor([[[ 0.0000,  0.0000,  0.0000],
                 [ 0.1535, -2.0309,  0.9315],
                 [ 0.0000,  0.0000,  0.0000],
                 [-0.1655,  0.9897,  0.0635]]])

        >>> # example of changing `pad` vector
        >>> padding_idx = 0
        >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
        >>> embedding.weight
        Parameter containing:
        tensor([[ 0.0000,  0.0000,  0.0000],
                [-0.7895, -0.7089, -0.0364],
                [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
        >>> with torch.no_grad():
        ...     embedding.weight[padding_idx] = torch.ones(3)
        >>> embedding.weight
        Parameter containing:
        tensor([[ 1.0000,  1.0000,  1.0000],
                [-0.7895, -0.7089, -0.0364],
                [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
    """
    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
                     'norm_type', 'scale_grad_by_freq', 'sparse']

    num_embeddings: int
    embedding_dim: int
    padding_idx: Optional[int]
    max_norm: Optional[float]
    norm_type: float
    scale_grad_by_freq: bool
    weight: Tensor
    sparse: bool

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None,
                 device=None, dtype=None) -> None:
        """

        接收参数:
            函数接收多个参数,包括num_embeddings(词典大小)、embedding_dim(嵌入向量的维度)、padding_idx(填充索引)、max_norm(最大范数)
            、norm_type(范数类型)、scale_grad_by_freq(是否按频率缩放梯度)、sparse(是否使用稀疏梯度)、_weight(可选的权重张量)
            、device(设备,如CPU或GPU)和dtype(数据类型)。

        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        # 检查padding_idx是否在合法范围内。
        # 如果padding_idx为负数,将其转换为正数索引,确保它在num_embeddings范围内。
        # 最后,将处理后的padding_idx赋值给当前对象的属性
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq

        # 初始化权重weight
        # 如果_weight参数没有给出,那么创建一个新的张量作为权重,并调用reset_parameters()方法对其进行初始化。
        # 如果_weight提供了,那么验证其形状是否与num_embeddings和embedding_dim相匹配,然后将其作为权重张量。
        if _weight is None:
            self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
            self.reset_parameters()
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)

        self.sparse = sparse

    def reset_parameters(self) -> None:
        """
            init.normal_():
                使用正态分布来初始化权重张量self.weight
                将self.weight张量的元素初始化为来自标准正态分布的随机数。
                这有助于打破对称性,确保网络的每一层都有不同的初始化状态,从而避免所有神经元学习相同的功能
        :return:
        """
        init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        """
        padding_idx是Embedding层中的一个特殊索引,用于标识哪些嵌入向量应该被视为“填充”数据,在训练中通常不参与梯度计算

        self.weight[self.padding_idx].fill_(0):
            self.weight[self.padding_idx].fill_(0):如果padding_idx存在,那么将self.weight张量中padding_idx对应位置的向量用0填充。
            这意味着在该索引位置上的嵌入向量将被初始化为全零向量,这通常用于表示“填充”或“空白”词汇的嵌入,确保它们在计算梯度时不产生影响
        :return:
        """
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        return F.embedding(
            input, self.weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)

继续:

    def extra_repr(self) -> str:
        s = '{num_embeddings}, {embedding_dim}'
        if self.padding_idx is not None:
            s += ', padding_idx={padding_idx}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        if self.sparse is not False:
            s += ', sparse=True'
        return s.format(**self.__dict__)

用于生成一个字符串,描述嵌入层的一些关键属性和配置。extra_repr方法在PyTorch中用于提供模型或模块的附加信息,通常在打印模型结构或调试时使用。下面是逐行解析和整体解释:

def extra_repr(self) -> str: 定义了extra_repr方法,它返回一个字符串。
s = '{num_embeddings}, {embedding_dim}' 初始化字符串s,包含了嵌入层的基本信息:num_embeddings(词汇表大小)和embedding_dim(嵌入向量的维度)。
if self.padding_idx is not None: 如果padding_idx被设置,意味着有特定的索引用于填充操作,将在s字符串中加入这部分信息。
s += ', padding_idx={padding_idx}' padding_idx的信息添加到s中。
if self.max_norm is not None: 如果max_norm被设置,表示对嵌入向量的范数有限制,将在s字符串中加入这部分信息。
s += ', max_norm={max_norm}'max_norm的信息添加到s中。
if self.norm_type != 2: 如果norm_type不等于2(默认的L2范数),意味着使用了不同的范数类型,将在s字符串中加入这部分信息。
s += ', norm_type={norm_type}' norm_type的信息添加到s中。
if self.scale_grad_by_freq is not False: 如果scale_grad_by_freq被启用,表示在反向传播时,梯度将按频率缩放,将在s字符串中加入这部分信息。
s += ', scale_grad_by_freq={scale_grad_by_freq}' scale_grad_by_freq的信息添加到s中。
if self.sparse is not False: 如果sparse被设置为True,表示使用稀疏梯度更新,将在s字符串中加入这部分信息。
s += ', sparse=True' sparse的信息添加到s中,注意这里总是显示为True,因为只有当sparseTrue时,才会添加这部分信息。
return s.format(**self.__dict__) 使用self.__dict__中的键值对替换s字符串中的格式化占位符,然后返回最终的字符串。
总结来说extra_repr方法通过构建一个描述性的字符串,提供了关于nn.Embedding实例的详细配置信息,包括词汇表大小、嵌入维度、填充索引、最大范数、范数类型、梯度按频率缩放以及是否使用稀疏更新等。这个字符串在打印模型结构或日志输出时非常有用,可以帮助理解或调试模型的具体设置。

继续最后一部分:

    def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
                        max_norm=None, norm_type=2., scale_grad_by_freq=False,
                        sparse=False):
        r"""从给定的二维FloatTensor创建Embedding实例。

        参数:
            embeddings (Tensor): 包含用于Embedding的权重的FloatTensor。第一维作为num_embeddings传递给Embedding,第二维作为embedding_dim。
            freeze (布尔值, 可选): 如果为True,则该张量在学习过程中不会更新。等价于embedding.weight.requires_grad = False。默认: True
            padding_idx (int, 可选): 如果指定,则在位置:attr:padding_idx处的条目不会对梯度有贡献;因此,在:attr:padding_idx处的嵌入向量在训练过程中不会被更新,也就是说,它保持为一个固定的"填充"不变。
            max_norm (float, 可选): 请参阅模块初始化文档。
            norm_type (float, 可选): 请参阅模块初始化文档。默认值为2。
            scale_grad_by_freq (布尔值, 可选): 请参阅模块初始化文档。默认为False。
            sparse (bool, 可选): 请参阅模块初始化文档。

        Examples::

            >>> # FloatTensor containing pretrained weights
            >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
            >>> embedding = nn.Embedding.from_pretrained(weight)
            >>> # Get embeddings for index 1
            >>> input = torch.LongTensor([1])
            >>> embedding(input)
            tensor([[ 4.0000,  5.1000,  6.3000]])
        """
        assert embeddings.dim() == 2, \
            'Embeddings parameter is expected to be 2-dimensional'
        rows, cols = embeddings.shape
        embedding = cls(
            num_embeddings=rows,
            embedding_dim=cols,
            _weight=embeddings,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse)
        embedding.weight.requires_grad = not freeze
        return embedding

首先检查embeddings是否为二维,如果不是,则抛出异常。
然后从embeddings张量的形状中获取行数和列数,分别作为num_embeddingsembedding_dim
使用这些参数创建一个新的Embedding实例,并传递所有其他可选参数。
最后,根据freeze参数决定是否允许weight张量的梯度计算。
返回创建的Embedding实例。

完结!

  • 41
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值