PyTorch reshape vs. view performance comparison (plus einsum vs. matmul)

reshape and view

The exact semantics and differences of reshape and view are in the official docs, so just a short summary: view never moves any data, it only changes the indexing metadata (shape and strides), and it raises an error when the requested shape is incompatible with the tensor's memory layout; reshape returns a view whenever it can, and otherwise falls back to copying the data.
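
A minimal sketch of the difference: after a transpose the tensor is no longer contiguous, so view refuses to produce some shapes while reshape silently falls back to a copy.

import torch

x = torch.rand(4, 6)
y = x.t()                     # same storage as x, but non-contiguous strides

print(y.reshape(3, 8).shape)  # works: reshape copies because it must
try:
    y.view(3, 8)              # fails: no stride pattern can express this
except RuntimeError as e:
    print(e)                  # the error suggests using .reshape(...) instead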

So which one should we pick in practice?

Conclusion first: they perform about the same.
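
Which is what you would expect: on a contiguous tensor, reshape returns a view and copies nothing, so the two calls do essentially the same metadata update. A quick check that the storage is shared:

import torch

x = torch.rand(128, 300, 30, 10)      # contiguous
a = x.view(100, 64, 60, 30)
b = x.reshape(100, 64, 60, 30)

# Same underlying memory for all three, i.e. reshape did not copy.
print(x.data_ptr() == a.data_ptr() == b.data_ptr())  # True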

einsum and matmul

For usage details, again see the official docs. When we want to do complicated matrix multiplications on high-dimensional tensors, we usually reach for einsum because the subscript notation is clear and simple. But when the same product can also be written with matmul, should we worry about which of the two is faster?

Conclusion: when the matmul involves broadcasting, einsum is faster. I did not benchmark the simple (non-broadcast) cases, but in those cases matmul is the more convenient choice anyway.
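
For reference, here is a minimal sketch (small stand-in shapes) showing that the einsum expression used in the benchmark below and a broadcast matmul compute the same thing. The benchmark's matmul call gets away with passing data1[i] directly only because t == 1; the unsqueeze here makes the extra broadcast dim explicit so the equivalence holds for any t:

import torch

b, t, n, f, k, c = 4, 2, 8, 30, 20, 30  # small stand-ins for the benchmark shapes
x = torch.rand(b, t, n, f)
w = torch.rand(k, f, c)

# einsum: contract over f, keep an independent k axis -> (b, t, k, n, c)
out_einsum = torch.einsum('btnf, kfc -> btknc', x, w)

# matmul: (b, t, 1, n, f) @ (k, f, c) broadcasts the batch dims
# (b, t, 1) against (k,) to give (b, t, k, n, c)
out_matmul = torch.matmul(x.unsqueeze(2), w)

print(torch.allclose(out_einsum, out_matmul, atol=1e-5))  # True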

If you re-implement what einsum does with an explicit Python for loop, it is much slower, so never use a for loop for this!

Code

import torch
import time
from collections import defaultdict


def generate_tensor(shape, num=1):
    # Build `num` random tensors of the given shape directly on the GPU.
    return [torch.rand(shape).cuda() for _ in range(num)]


def bench_view(datas, shape_to, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for d in datas:
            tmp = d.view(shape_to)
            res.append(tmp)
        # Concatenate and index the results so the ops are actually used.
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out))
    
    
def bench_reshape(datas, shape_to, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for d in datas:
            tmp = d.reshape(shape_to)
            res.append(tmp)
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out))
    
def bench_einsum(data1, data2, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for i in range(len(data1)):
            # Contract over f while keeping an independent k output axis.
            tmp = torch.einsum('btnf, kfc -> btknc', data1[i], data2[i])
            res.append(tmp)
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out))


def bench_matmul(data1, data2, repeat=1):
    out = []
    for _ in range(repeat):
        res = []
        for i in range(len(data1)):
            # NOTE: an explicit Python for loop over a batch axis is far
            # slower; keep everything in one broadcast matmul instead.
            # tmp = []
            # for j in range(data1[i].shape[1]):
            #     left = data1[i][:,j,...].unsqueeze(dim=1)
            #     tmp.append(torch.matmul(left, data2[i]))
            # tmp = torch.stack(tmp, dim=1)

            # (128, 1, 300, 30) @ (20, 30, 30): the batch dims broadcast
            # to (128, 20) and the result is (128, 20, 300, 30).
            tmp = torch.matmul(data1[i], data2[i])
            res.append(tmp)
        out.append(torch.cat(res, dim=0)[:2].shape)
    print(len(out), out[0])
    
    

class StopWatch:
    def __init__(self) -> None:
        self.times = defaultdict(list)

    def tk(self, key='default'):
        # CUDA kernels are launched asynchronously, so synchronize first;
        # otherwise the wall-clock delta can miss most of the GPU work.
        torch.cuda.synchronize()
        tkn = time.time()
        if key in self.times:
            print(f'key={key}, elapsed:', tkn - self.times[key][-1])

        self.times[key].append(tkn)


def test_view_and_reshape():
    # 128*300*30*10 == 100*64*60*30 == 11,520,000 elements, so the target
    # shape is valid for both view and reshape.
    datas = generate_tensor(shape=(128, 300, 30, 10), num=10)
    sw = StopWatch()

    sw.tk()
    bench_view(datas, shape_to=(100, 64, 60, 30), repeat=1000)
    # bench_reshape(datas, shape_to=(100, 64, 60, 30), repeat=1000)
    sw.tk()
    
    
def test_einsum_matmul():
    data1 = generate_tensor(shape=(128, 1, 300, 30), num=5)
    data2 = generate_tensor(shape=(20, 30, 30), num=5)
    sw = StopWatch()
    
    sw.tk()
    bench_einsum(data1, data2, repeat=1000)
    # bench_matmul(data1, data2, repeat=1000)
    sw.tk()
    
if __name__ == '__main__':
    # test_view_and_reshape()
    test_einsum_matmul()
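
Since timing asynchronous CUDA calls with time.time() is easy to get wrong even with explicit synchronization, a common alternative is to time with CUDA events. A minimal sketch, reusing the shapes from test_einsum_matmul above:

import torch

x = torch.rand(128, 1, 300, 30).cuda()
w = torch.rand(20, 30, 30).cuda()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
for _ in range(1000):
    out = torch.einsum('btnf, kfc -> btknc', x, w)
end.record()
torch.cuda.synchronize()                    # wait until `end` has been recorded
print(f'{start.elapsed_time(end):.1f} ms')  # elapsed_time is in milliseconds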
    
    