PyTorch 中contiguous函数使用详解和代码演示

在 PyTorch 中,contiguous() 是一个用于 张量内存布局优化 的函数。它的作用是在需要时返回一个内存布局为连续(contiguous)的张量,常用于 transpose、permute 等操作后。


一、为什么需要 contiguous()

PyTorch 的张量是以 行优先(row-major)顺序 存储的。当你对张量使用 transpose()permute() 等操作时,虽然张量的维度看起来改变了,但底层的内存并没有重新排列,只是修改了索引方式。

一些 PyTorch 函数(如 .view())要求输入张量必须是 连续的内存块,否则就会报错。


二、函数定义与用法

Tensor.contiguous(memory_format=torch.contiguous_format) → Tensor

返回值:

返回一个与当前张量具有相同数据但在内存中 连续排列 的副本。如果当前张量已经是连续的,就直接返回自身。


三、典型使用场景

1. view() 前需要 .contiguous()

x = torch.randn(2, 3, 4)
y = x.permute(1, 0, 2)  # 改变维度顺序
z = y.contiguous().view(3, 8)  # 安全 reshape

如果不加 .contiguous()

z = y.view(3, 8)  # ⚠️ 报错:RuntimeError: view size is not compatible with input tensor's size and stride

2. 使用 transpose() 后需要 .contiguous() 参与后续操作

a = torch.randn(10, 20)
b = a.transpose(0, 1)  # Not contiguous now
b = b.contiguous()     # 重新在内存中复制数据为连续块

四、查看是否是连续的

x.is_contiguous()

五、底层原理简要

PyTorch 张量有 .stride() 属性定义每一维的跳步。连续的张量满足:

x.stride()[i] = product(x.shape[i+1:])

一旦 .transpose() / .permute() 修改了维度顺序,这个规则就被破坏,因此 .contiguous() 会重新分配内存来确保是连续的。


六、contiguous()项目演示

下面是一个完整的 PyTorch 小项目,演示 .contiguous() 的必要性与作用。将看到在对张量进行 permute() 后,使用 .view() reshape 会失败,只有 .contiguous() 可以解决问题。

项目内容:张量维度变换与 .contiguous() 对比演示

项目结构:

contiguous_demo/
├── main.py
└── requirements.txt

requirements.txt

torch>=2.0

main.py

import torch

def describe_tensor(tensor, name):
    print(f"{name}: shape={tensor.shape}, strides={tensor.stride()}, is_contiguous={tensor.is_contiguous()}")

def main():
    print("=== 创建张量 ===")
    x = torch.randn(2, 3, 4)  # 原始张量 shape [2, 3, 4]
    describe_tensor(x, "x")

    print("\n=== 进行 permute 操作(交换维度) ===")
    y = x.permute(1, 0, 2)  # shape: [3, 2, 4]
    describe_tensor(y, "y (after permute)")

    print("\n尝试 view reshape 到 [3, 8](不使用 contiguous)")
    try:
        z = y.view(3, 8)  # ⚠️ 报错:因为 y 的内存不是连续的
    except RuntimeError as e:
        print(f"RuntimeError: {e}")

    print("\n=== 使用 .contiguous() 后 reshape ===")
    y_contig = y.contiguous()
    describe_tensor(y_contig, "y_contig (after .contiguous())")

    z = y_contig.view(3, 8)
    describe_tensor(z, "z (reshaped)")

    print("\n✅ reshape 成功,结果如下:")
    print(z)

if __name__ == "__main__":
    main()

运行方法

  1. 安装依赖:
pip install -r requirements.txt
  1. 运行程序:
python main.py

运行结果概览

将看到:

  • 原始张量是连续的;
  • permute() 后变成非连续;
  • 使用 .view() 报错;
  • .contiguous() 修复内存后成功 reshape。

小总结要点

操作是否连续能否 .view()
原始张量✅ 是✅ 是
permute()❌ 否❌ 报错
.contiguous()✅ 是✅ 是

总结记忆:

操作是否影响连续性?是否需要 .contiguous()
view()❗ 需要连续✅ 是
permute() / transpose()破坏连续性✅ 是
reshape()自动处理❌ 不需要(内部处理)
<think>嗯,用户让我详细介绍一下Transformer模型的代码。首先,我需要确认他们对Transformer的理解程度。可能他们已经有了一些理论基础,现在想通过代码来加深理解。我得先从模型的结构开始,分模块讲解,比如自注意力机制、多头注意力、位置编码等等。 用户可能希望看到具体的代码示例,而不仅仅是理论。所以,我需要用PyTorch这样的常见框架来展示,这样更容易上手。同时,要注意代码的注释,解释每一步的作用,比如Q、K、V的计算,缩放点积注意力,还有多头注意力的实现方式。 还要考虑用户可能对某些细节不太清楚,比如位置编码的实现,或者前馈网络的结构。这部分需要详细说明,比如使用正弦函数余弦函数来生成位置信息,以及如何将位置编码加到输入向量上。 另外,用户可能想知道各个模块是如何组合在一起的,比如编码器解码器的堆叠,掩码的使用等等。这里需要展示完整的Transformer类,并解释每个参数的作用,比如词嵌入维度、头数、层数等。 在写代码的时候,要注意遵循系统级指令,比如数学公式用$...$,独立公式用$$...$$。例如,在解释缩放点积注意力时,公式要用LaTeX正确表示。同时,结构要清晰,分步骤讲解,确保用户能一步步理解。 可能还需要指出一些常见的实现细节,比如为什么使用Layer Normalization,残差连接的作用,以及训练推理时解码器的不同处理(掩码的使用)。这些点对于正确实现Transformer模型非常重要,用户可能会在这里遇到问题。 最后,要提醒用户在实际应用中的注意事项,比如参数调整、数据预处理、训练技巧等。虽然用户主要问的是代码详解,但提供一些实用建议会更有帮助。同时,要确保代码示例正确,没有语法错误,用户可以复制粘贴后进行测试。</think>我将分步骤为您详解Transformer模型的核心代码实现(基于PyTorch框架)。让我们从关键组件到完整架构逐步解析: ```python import torch import torch.nn as nn import math # ====================== # 核心组件实现 # ====================== # 1. 缩放点积注意力(Scaled Dot-Product Attention) def scaled_dot_product_attention(Q, K, V, mask=None): """ 计算:$Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$ """ d_k = Q.size(-1) scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attention = torch.softmax(scores, dim=-1) return torch.matmul(attention, V) # 2. 多头注意力机制 class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super().__init__() assert d_model % num_heads == 0 self.d_model = d_model # 模型维度,如512 self.num_heads = num_heads # 头数,如8 self.d_k = d_model // num_heads # 线性变换矩阵 self.W_Q = nn.Linear(d_model, d_model) self.W_K = nn.Linear(d_model, d_model) self.W_V = nn.Linear(d_model, d_model) self.W_O = nn.Linear(d_model, d_model) def split_heads(self, x): """ 将输入张量重塑为:(batch_size, num_heads, seq_len, d_k) """ batch_size = x.size(0) return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) def forward(self, Q, K, V, mask=None): # 线性变换 + 分头 Q = self.split_heads(self.W_Q(Q)) # (B, h, L, d_k) K = self.split_heads(self.W_K(K)) V = self.split_heads(self.W_V(V)) # 计算注意力 attention = scaled_dot_product_attention(Q, K, V, mask) # 拼接多头结果 attention = attention.transpose(1, 2).contiguous() # (B, L, h, d_k) original_size = (attention.size(0), -1, self.d_model) output = self.W_O(attention.view(*original_size)) return output # 3. 位置前馈网络 class PositionWiseFFN(nn.Module): def __init__(self, d_model, d_ff=2048): super().__init__() self.linear1 = nn.Linear(d_model, d_ff) self.linear2 = nn.Linear(d_ff, d_model) self.relu = nn.ReLU() def forward(self, x): return self.linear2(self.relu(self.linear1(x))) # 4. 位置编码 class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=512): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0)/d_model)) pe[:, 0::2] = torch.sin(position * div_term) # 偶数位置 pe[:, 1::2] = torch.cos(position * div_term) # 奇数位置 self.register_buffer('pe', pe.unsqueeze(0)) def forward(self, x): return x + self.pe[:, :x.size(1)] # 加入位置编码 # ====================== # Transformer完整架构 # ====================== class TransformerEncoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1): super().__init__() self.self_attn = MultiHeadAttention(d_model, num_heads) self.ffn = PositionWiseFFN(d_model, d_ff) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x, mask=None): # 自注意力 + 残差连接 attn_output = self.self_attn(x, x, x, mask) x = self.norm1(x + self.dropout(attn_output)) # 前馈网络 + 残差连接 ffn_output = self.ffn(x) x = self.norm2(x + self.dropout(ffn_output)) return x class TransformerDecoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1): super().__init__() self.self_attn = MultiHeadAttention(d_model, num_heads) self.cross_attn = MultiHeadAttention(d_model, num_heads) self.ffn = PositionWiseFFN(d_model, d_ff) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x, encoder_output, src_mask, tgt_mask): # 自注意力(解码器端) attn_output = self.self_attn(x, x, x, tgt_mask) x = self.norm1(x + self.dropout(attn_output)) # 交叉注意力(编码器-解码器) attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask) x = self.norm2(x + self.dropout(attn_output)) # 前馈网络 ffn_output = self.ffn(x) x = self.norm3(x + self.dropout(ffn_output)) return x class Transformer(nn.Module): def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, dropout=0.1): super().__init__() # 编码器部分 self.encoder_embedding = nn.Embedding(src_vocab_size, d_model) self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model) self.pos_encoding = PositionalEncoding(d_model) self.encoder_layers = nn.ModuleList( [TransformerEncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]) self.decoder_layers = nn.ModuleList( [TransformerDecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]) self.final_linear = nn.Linear(d_model, tgt_vocab_size) self.dropout = nn.Dropout(dropout) def encode(self, src, src_mask): x = self.encoder_embedding(src) x = self.pos_encoding(x) for layer in self.encoder_layers: x = layer(x, src_mask) return x def decode(self, tgt, encoder_output, src_mask, tgt_mask): x = self.decoder_embedding(tgt) x = self.pos_encoding(x) for layer in self.decoder_layers: x = layer(x, encoder_output, src_mask, tgt_mask) return x def forward(self, src, tgt, src_mask, tgt_mask): encoder_output = self.encode(src, src_mask) decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask) return self.final_linear(decoder_output) ``` 关键实现细节说明: 1. **缩放点积注意力**: - 使用$\sqrt{d_k}$进行缩放,防止softmax梯度消失 - 通过mask实现padding masksequence mask 2. **多头注意力**: - 将输入拆分为多个头并行计算 - 最后通过线性层合并结果 - 计算复杂度:$O(L^2 \cdot d)$ 3. **位置编码**: - 使用正弦/余弦函数的固定编码方案 - 公式:$$PE_{(pos,2i)} = \sin(pos/10000^{2i/d_{model}})$$ $$PE_{(pos,2i+1)} = \cos(pos/10000^{2i/d_{model}})$$ 4. **前馈网络**: - 包含两个线性变换ReLU激活 - 中间维度通常设为4倍模型维度(如512→2048) 5. **残差连接与层归一化**: - 每个子层输出:LayerNorm(x + Sublayer(x)) - 帮助缓解梯度消失问题 使用示例: ```python # 参数设置 d_model = 512 num_heads = 8 num_layers = 6 src_vocab = 10000 tgt_vocab = 8000 model = Transformer(src_vocab, tgt_vocab, d_model, num_heads, num_layers) # 假设输入 src = torch.randint(0, src_vocab, (32, 100)) # batch_size=32, seq_len=100 tgt = torch.randint(0, tgt_vocab, (32, 90)) src_mask = torch.ones(32, 1, 100) # 假设没有padding tgt_mask = torch.tril(torch.ones(90, 90)) # 解码器掩码 output = model(src, tgt, src_mask, tgt_mask) print(output.shape) # torch.Size([32, 90, 8000]) ``` 实际应用注意事项: 1. 需要根据任务调整超参数(层数、头数、维度等) 2. 训练时需处理不同长度的序列(padding mask) 3. 解码时使用自回归生成(teacher forcing或beam search) 4. 建议使用学习率warmup策略 5. 对于大数据集,推荐使用混合精度训练 此实现完整展示了Transformer的核心架构,可在此基础上进行修改以适应不同任务(如BERT、GPT等变体)。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

点云SLAM

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值