这是在使用 timm 的 vision_transformer.py 文件时报的错,使用的torch 版本是1.12
报错代码如下:
def forward(self, x: torch.Tensor, freqs_cis) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
q[:, :, 1:], k[:, :, 1:] = apply_rotary_emb(q[:, :, 1:], k[:, :, 1:], freqs_cis=freqs_cis)
此类故障常见的解决方案可以参考链接 Pytorch 1.7.0 RuntimeError · Issue #130 · shenweichen/DeepCTR-Torch (github.com)
但是此方法不适用于这里的代码,经过多次测试,发现问题在于q, k, v = qkv.unbind(0)
修改为 q, k, v = qkv[0], qkv[1], qkv[2] 即可以正常运行,主要是因为unbind改变了grad_fn
如下图所示: