采用python实现torch.matmul


a=torch.ones(2,5,3)
b= torch.ones(1,3,4)

for i in range(2):
    for j in range(5):
        a[i,j,:]=i+j

c=torch.matmul(a,b)
print(a)
print(b)
print(c)
print(c.shape)
# pdb.set_trace()

c1=torch.zeros(2,5,4)
for i in range(2):
    for j in range(5):
        for k in range(4):
            tmp=0
            for k1 in range(3):
                tmp=tmp+a[i,j,k1]*b[0,k1,k]
            c1[i,j,k]=tmp

print(c1.shape)
print(c1)

结果:

tensor([[[0., 0., 0.],
         [1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.],
         [4., 4., 4.]],

        [[1., 1., 1.],
         [2., 2., 2.],
         [3., 3., 3.],
         [4., 4., 4.],
         [5., 5., 5.]]])
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
tensor([[[ 0.,  0.,  0.,  0.],
         [ 3.,  3.,  3.,  3.],
         [ 6.,  6.,  6.,  6.],
         [ 9.,  9.,  9.,  9.],
         [12., 12., 12., 12.]],

        [[ 3.,  3.,  3.,  3.],
         [ 6.,  6.,  6.,  6.],
         [ 9.,  9.,  9.,  9.],
         [12., 12., 12., 12.],
         [15., 15., 15., 15.]]])
torch.Size([2, 5, 4])
torch.Size([2, 5, 4])
tensor([[[ 0.,  0.,  0.,  0.],
         [ 3.,  3.,  3.,  3.],
         [ 6.,  6.,  6.,  6.],
         [ 9.,  9.,  9.,  9.],
         [12., 12., 12., 12.]],

        [[ 3.,  3.,  3.,  3.],
         [ 6.,  6.,  6.,  6.],
         [ 9.,  9.,  9.,  9.],
         [12., 12., 12., 12.],

可见,可以采用类似的方法来实现。


        batch_T = torch.matmul(self.inv_delta_C, batch_C_prime_with_zeros)  # batch_size x F+3 x 2
        t3=time.time()
        # pdb.set_trace()
        batch_P_prime = torch.matmul(self.P_hat, batch_T)  # batch_size x n x 2
        t4=time.time()
        b2=torch.zeros(1,102400, 2)
        for i  in range(102400):
            if i%10==0:
                print(i)
            for j in range(2):
                tmp=0
                for k in range(964):
                    tmp=tmp+self.P_hat[i,k]*batch_T[0,k,j]
                b2[0,i,j]=tmp
        
        for i  in range(102400):
            for j in range(2):
                if b2[0,i,j]!=batch_P_prime[0,i,j]:
                    print(i,j,b2[0,i,j]-batch_P_prime[0,i,j])
                    pdb.set_trace()

不用库,确实特别慢啊,特别慢。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值