pytorch学习笔记----手写conv2d考虑batchsize and channel

参考b站up主  deep_thoughts

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

def ManualConv2d_MatrixCaculation_batchsize_channel(input , kernel , bias = 0 , stride = 1 , padding = 0 ):
     #input  kernel size is [ batch size  , channel size  , wide  , high ]
     if(padding > 0):
          input = F.pad(input , (padding ,padding ,padding ,padding , 0 , 0 ,0,0))
     bs , inchannel , input_h , input_w = input.shape
     out_channel , inchannel , kernel_h , kernel_w = kernel.shape
     if bias is None :
          bias = torch.zeros (out_channel)   

     output_w  = (math.floor((input_h - kernel_h)/stride) + 1) # conv output wide 
     output_h  = (math.floor((input_w - kernel_w)/stride) + 1) # conv output height
     output = torch.zeros(bs,out_channel,output_h , output_w) # initial output matrix 
     
     for ind in range(bs):
        for  oc in range(out_channel):
            for  ic in range(inchannel):
               for i in range(0  , input_h - kernel_h + 1 , stride): # 对高度维进行遍历
                    for j in range(0 , input_w - kernel_w + 1 , stride): # 对宽度维进行遍历
                        region = input[ind , ic , i : i + kernel_h , j : j + kernel_w] # 
                        output[ ind , oc , int(i/stride) , int(j/stride)] += torch.sum(region * kernel[oc , ic])            
            output[ind , oc ] += bias[oc]
     return output 



input = torch.randn(2 ,2, 5 ,5)
kernnel = torch.randn(3, 2, 3, 3)
bias = torch.randn(3)


pytorch_api_conv2d = F.conv2d(input, kernnel , bias=bias , padding = 1 ,stride= 2)
mm_pytorch_full_output = ManualConv2d_MatrixCaculation_batchsize_channel(input , kernnel  , bias= bias , padding= 1 ,stride= 2)


flag = torch.allclose(pytorch_api_conv2d , mm_pytorch_full_output)
print(flag)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值