rearrange 和 einsum 真的优雅吗

本文探讨了使用PyTorch实现的简单QKV模块,通过对比使用Einsum和不使用的方法,展示了它们在代码量和ONNX转换后的相似性。实验结果显示,两种实现方式输出一致,且优化后的ONNX模型表现相近。
摘要由CSDN通过智能技术生成

结论是,还好吧。

从代码量看,差不多:

# -*- coding:utf-8 -*-
import torch
from torch import nn
from torch import einsum
from einops import rearrange

class SimpleQKV(nn.Module):
    def __init__(self, dim, use_ein):
        super().__init__()
        self.proj = nn.Linear(dim, dim*3, bias=False)
        self.dim = dim
        self.scale = self.dim ** -0.5
        self.use_ein = use_ein
        torch.manual_seed(777) # 为了使权重相同,便于比较输出
        nn.init.xavier_uniform_(self.proj.weight)        

    def forward(self, x):
        n,c,h,w = x.shape
        #assert c==self.dim
        if (self.use_ein):
            x = rearrange(x, 'n c h w -> n (h w) c')
        else:
            x = x.permute(0,2,3,1).view(n, -1, c)
        qkv = self.proj(x)        
        q,k,v = qkv.chunk(chunks=3,dim=-1)        
        if (self.use_ein):
            attn = (einsum('n i c, n j c -> n i j', q, k) * self.scale).softmax(dim=-1)        
            v = einsum('n i j, n j c -> n i c', attn, v)
            output = rearrange(v, 'n (h w) c -> n c h w', h=h)
        else:            
            attn = (torch.matmul(q, k.transpose(1,2)) * self.scale).softmax(dim=-1)
            v = torch.matmul(attn, v)
            output = v.permute(0,2,1).view(n,c,h,w)
        
        return output

batch, chan, height, width = 1, 20, 7, 7
simple_qkv_ein = SimpleQKV(chan, True)
simple_qkv_noein = SimpleQKV(chan, False)

x = torch.randn(batch, chan, height, width, device='cpu')
out1 = simple_qkv_ein(x)
out2 = simple_qkv_noein(x)
assert(out1.equal(out2))

# 保存onnx
simple_qkv_ein.eval()
onnx_filename = './simple_qkv_ein.onnx'
torch.onnx.export(simple_qkv_ein, x, onnx_filename,
                  input_names=['input'], output_names=['ouput'],
                  export_params=True, verbose=False, opset_version=12)

simple_qkv_noein.eval()
onnx_filename = './simple_qkv_noein.onnx'
torch.onnx.export(simple_qkv_noein, x, onnx_filename,
                  input_names=['input'], output_names=['ouput'],
                  export_params=True, verbose=False, opset_version=12)

print('save onnx succ.')

从保存的onnx看(经过 onnxsim 优化),也差不多:

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值