搞懂GPT2张量输入输出结构:shape维度、logits切片与squeeze实战解析

本文深入讲解 GPT2 在推理过程中的张量维度结构、logits[:, -1, :] 的含义、squeeze 的使用,以及常见的字典推导式写法,帮助你在调试 Hugging Face 模型时不再迷茫。


📐 什么是 shape?

在 PyTorch 中,shape 表示一个 张量在每个维度上有多少个元素,也就是数据的“形状规格”。

示例张量shape含义
tensor(7)()标量,0 维
tensor([1, 2, 3])(3,)向量,1 维
tensor([[1, 2, 3], [4, 5, 6]])(2, 3)矩阵,2 维
tensor([[[1,2],[3,4]],[[5,6],[7,8]]])(2, 2, 2)三维张量

✉️ GPT2 的输入输出维度结构

GPT2 接受的输入通常是二维张量:

input_ids.shape = [batch_size, sequence_length]
  • batch_size:一次输入多少条句子(例:2)
  • sequence_length:每条句子被 token 化后有多少个词(例:10)

输出 logits 是三维张量:

logits.shape = [batch_size, sequence_length, vocab_size]
  • vocab_size:词表大小(GPT2 默认 50257)
  • 每个 token 会输出一个长度为 vocab_size 的分数向量

🔍 理解 logits[:, -1, :]

这是最常见的切片写法,表示:

  • ::选中所有样本(batch)
  • -1:选中每个句子的最后一个 token 的位置
  • ::该 token 对词表中每个词的打分

📌 结果 shape 为 [batch_size, vocab_size]

logits = model(input_ids).logits
next_token_logits = logits[:, -1, :]  # shape: [B, V]

用途:预测“下一个词”应该是什么。


🔄 维度压缩:squeeze()

示例 1:压缩第 0 维

x = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
x_squeezed = x.squeeze(0)      # shape: [3]

示例 2:压缩失败(因为第 0 维不是 1)

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
x_squeezed = x.squeeze(0)                # 仍是 [2, 3]

🔁 .squeeze(dim) 只有当该维是 1 时才会生效,否则保持原样。


🧠 典型写法:字典推导式 + squeeze

在 Hugging Face 推理时,你经常会看到如下代码:

inputs = tokenizer("你好", return_tensors="pt")
inputs = {k: v.squeeze(0) for k, v in inputs.items()}

解释:

  • inputs.items():遍历分词器输出的每个键值对(如 input_ids, attention_mask)
  • .squeeze(0):去掉 batch_size=1 的维度,方便后续处理

等价写法(展开):

new_inputs = {}
for k, v in inputs.items():
    new_inputs[k] = v.squeeze(0)

🧪 logits 输出结构可视化

假设:

logits.shape = [2, 5, 50257]

可视化理解如下:

logits = [
  [ [token1_logits], ..., [token5_logits] ],  # 第1句话
  [ [token1_logits], ..., [token5_logits] ]   # 第2句话
]

再看 logits[:, -1, :]

[
  [50257 logits for1句话最后1],
  [50257 logits for2句话最后1]
]

即输出 shape 为 [2, 50257],每行是一个句子最后 token 的预测向量。


📌 总结

  • GPT2 输入是 [B, L],输出是 [B, L, V]
  • logits[:, -1, :] 是获取每句话最后一个 token 的预测输出
  • .squeeze(dim) 是压缩维度的利器,常配合字典推导式使用
  • 理解这些结构,调试模型再也不晕!

📌 YoanAILab 技术导航页

💡 项目源码 × 实战部署 × 转型经验,一页总览
👉 点击查看完整导航页

📚 包含内容:

  • 🧠 GPT-2 项目源码(GitHub)
  • ✍️ CSDN 技术专栏合集
  • 💼 知乎转型日志
  • 📖 公众号 YoanAILab 全文合集
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yoan AI Lab

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值