【HuggingFace Transformers】input_ids与inputs_embeds的区别

input_ids与inputs_embeds的区别

在使用 BERT 模型时,input_idsinputs_embeds 都是用于表示输入数据的,但它们有不同的用途和数据格式。以下是它们的区别和详细解释:

1. input_ids

定义input_idstoken的索引序列。

类型torch.Tensor,包含整数索引。

用途:直接将文本分词后得到的token索引作为模型的输入。

示例

# -*- coding: utf-8 -*-
# @time: 2024/7/11 18:58

import torch
from transformers import BertTokenizer, BertModel

# 加载预训练的 BERT 模型和分词器
model_name = 'google-bert/bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# 示例文本
text = "我爱学习"
inputs = tokenizer(text, truncation=True, padding=True, return_tensors='pt')

input_ids = inputs['input_ids']
attention_mask = inputs["attention_mask"]
token_type_ids = inputs["token_type_ids"]

print(input_ids)  # torch.Size([1, 6])

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

print(outputs.last_hidden_state)  # torch.Size([1, 6, 768])

输出

tensor([[ 101, 2769, 4263, 2110,  739,  102]])
tensor([[[ 0.0707,  0.1193, -0.1171,  ...,  0.8360,  0.1781, -0.3330],
         [ 0.3077,  0.0733, -0.0056,  ..., -0.9890, -0.0871, -0.1517],
         [ 1.0312, -0.3254, -1.0671,  ...,  0.1756,  0.2145, -0.1055],
         [ 0.1788, -0.1532, -0.9967,  ...,  0.4146,  0.1664, -0.4200],
         [ 0.8302, -0.5292, -0.7375,  ..., -0.1464,  0.2384,  0.2083],
         [ 0.3939,  0.1830, -0.2468,  ...,  0.8480,  0.1541,  0.1683]]])

2. inputs_embeds

定义inputs_embedstoken的嵌入表示。

类型torch.Tensor,包含浮点数的向量。

用途:当已经有了token的嵌入表示时,可以直接将这些嵌入作为模型的输入,而不是使用 input_ids 由模型内部的嵌入层进行转换。

示例

# -*- coding: utf-8 -*-
# @time: 2024/7/11 18:58

import torch
from transformers import BertTokenizer, BertModel

# 加载预训练的 BERT 模型和分词器
model_name = 'google-bert/bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# 示例文本
text = "我爱学习"
inputs_embeds = torch.randn(1, 6, 768)  # 假设inputs_embeds是"我爱学习"的token嵌入表示: [batch_size, seq_length, embedding_dim]
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0]])

print(inputs_embeds)  # torch.Size([1, 6, 768])

# 获取输入的最终嵌入表示
with torch.no_grad():
    outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, token_type_ids=token_type_ids)

print(outputs.last_hidden_state)  # torch.Size([1, 6, 768])

输出

tensor([[[-0.7986, -0.0758, -1.5470,  ..., -1.0122,  1.0138,  0.1089],
         [-0.8512, -0.2396, -0.1447,  ...,  1.5843,  1.1964, -0.2520],
         [-1.1290,  0.5671, -0.7088,  ...,  0.1192, -0.2578, -2.7265],
         [-0.4074, -0.7457,  0.3901,  ...,  0.1135,  1.2999,  1.0811],
         [-2.5027,  0.3060,  2.4674,  ...,  0.8905, -0.4032, -0.6385],
         [-1.3476, -1.0436,  2.7078,  ..., -0.8193,  0.5971,  0.2938]]])
tensor([[[-1.0209, -0.2582, -0.7404,  ...,  0.0584,  0.0331, -0.1135],
         [-0.9100,  0.4910,  0.1831,  ...,  0.7255, -0.0487, -0.0845],
         [-0.3630, -0.3341, -0.5843,  ...,  0.2195,  0.2563,  0.3027],
         [-0.5431, -0.4651, -0.4613,  ...,  0.3890, -0.6342,  0.0725],
         [-0.1561, -0.2302,  0.0517,  ..., -0.5924, -0.8542,  0.1668],
         [-0.4830, -0.8926,  0.2385,  ...,  0.5087,  0.1083,  0.0618]]])

3. 总结

  • input_idstoken的索引序列,用于标准的文本输入。
  • inputs_embedstoken的嵌入表示,用于自定义嵌入或高级处理场景。
  • 两者都可以作为 BERT 模型的输入,但不能同时使用
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CS_木成河

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值