input_ids与inputs_embeds的区别
在使用 BERT 模型时,input_ids
和 inputs_embeds
都是用于表示输入数据的,但它们有不同的用途和数据格式。以下是它们的区别和详细解释:
1. input_ids
定义:input_ids
是token的索引序列。
类型:torch.Tensor
,包含整数索引。
用途:直接将文本分词后得到的token索引作为模型的输入。
示例:
# -*- coding: utf-8 -*-
# @time: 2024/7/11 18:58
import torch
from transformers import BertTokenizer, BertModel
# 加载预训练的 BERT 模型和分词器
model_name = 'google-bert/bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
# 示例文本
text = "我爱学习"
inputs = tokenizer(text, truncation=True, padding=True, return_tensors='pt')
input_ids = inputs['input_ids']
attention_mask = inputs["attention_mask"]
token_type_ids = inputs["token_type_ids"]
print(input_ids) # torch.Size([1, 6])
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
print(outputs.last_hidden_state) # torch.Size([1, 6, 768])
输出:
tensor([[ 101, 2769, 4263, 2110, 739, 102]])
tensor([[[ 0.0707, 0.1193, -0.1171, ..., 0.8360, 0.1781, -0.3330],
[ 0.3077, 0.0733, -0.0056, ..., -0.9890, -0.0871, -0.1517],
[ 1.0312, -0.3254, -1.0671, ..., 0.1756, 0.2145, -0.1055],
[ 0.1788, -0.1532, -0.9967, ..., 0.4146, 0.1664, -0.4200],
[ 0.8302, -0.5292, -0.7375, ..., -0.1464, 0.2384, 0.2083],
[ 0.3939, 0.1830, -0.2468, ..., 0.8480, 0.1541, 0.1683]]])
2. inputs_embeds
定义:inputs_embeds
是token的嵌入表示。
类型:torch.Tensor
,包含浮点数的向量。
用途:当已经有了token的嵌入表示时,可以直接将这些嵌入作为模型的输入,而不是使用 input_ids
由模型内部的嵌入层进行转换。
示例:
# -*- coding: utf-8 -*-
# @time: 2024/7/11 18:58
import torch
from transformers import BertTokenizer, BertModel
# 加载预训练的 BERT 模型和分词器
model_name = 'google-bert/bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
# 示例文本
text = "我爱学习"
inputs_embeds = torch.randn(1, 6, 768) # 假设inputs_embeds是"我爱学习"的token嵌入表示: [batch_size, seq_length, embedding_dim]
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0]])
print(inputs_embeds) # torch.Size([1, 6, 768])
# 获取输入的最终嵌入表示
with torch.no_grad():
outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, token_type_ids=token_type_ids)
print(outputs.last_hidden_state) # torch.Size([1, 6, 768])
输出:
tensor([[[-0.7986, -0.0758, -1.5470, ..., -1.0122, 1.0138, 0.1089],
[-0.8512, -0.2396, -0.1447, ..., 1.5843, 1.1964, -0.2520],
[-1.1290, 0.5671, -0.7088, ..., 0.1192, -0.2578, -2.7265],
[-0.4074, -0.7457, 0.3901, ..., 0.1135, 1.2999, 1.0811],
[-2.5027, 0.3060, 2.4674, ..., 0.8905, -0.4032, -0.6385],
[-1.3476, -1.0436, 2.7078, ..., -0.8193, 0.5971, 0.2938]]])
tensor([[[-1.0209, -0.2582, -0.7404, ..., 0.0584, 0.0331, -0.1135],
[-0.9100, 0.4910, 0.1831, ..., 0.7255, -0.0487, -0.0845],
[-0.3630, -0.3341, -0.5843, ..., 0.2195, 0.2563, 0.3027],
[-0.5431, -0.4651, -0.4613, ..., 0.3890, -0.6342, 0.0725],
[-0.1561, -0.2302, 0.0517, ..., -0.5924, -0.8542, 0.1668],
[-0.4830, -0.8926, 0.2385, ..., 0.5087, 0.1083, 0.0618]]])
3. 总结
input_ids
是token的索引序列,用于标准的文本输入。inputs_embeds
是token的嵌入表示,用于自定义嵌入或高级处理场景。- 两者都可以作为 BERT 模型的输入,但不能同时使用。