# 参考b站up主 deep_thoughts (Reference: Bilibili uploader "deep_thoughts")
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
def ManualConv2d_MatrixCaculation_batchsize_channel(input, kernel, bias=0, stride=1, padding=0):
    """Compute a 2-D convolution (cross-correlation) with explicit Python loops.

    Reference implementation intended to match ``F.conv2d`` for an integer
    ``stride`` / ``padding`` applied equally to both spatial axes
    (``groups=1``, ``dilation=1``).

    Args:
        input: tensor of shape (batch, in_channels, height, width).
        kernel: tensor of shape (out_channels, in_channels, kernel_h, kernel_w).
        bias: tensor of shape (out_channels,), a scalar added to every output
            channel, or None for no bias. The default ``0`` adds nothing.
        stride: step between successive kernel positions (both axes).
        padding: number of zeros added on each side of both spatial axes.

    Returns:
        Tensor of shape (batch, out_channels, out_h, out_w) with the dtype and
        device of ``input``, where
        ``out_h = (height + 2*padding - kernel_h) // stride + 1`` and
        ``out_w = (width + 2*padding - kernel_w) // stride + 1``.

    Raises:
        ValueError: if ``input`` and ``kernel`` disagree on the number of input
            channels, if the kernel is larger than the padded input, or if
            ``bias`` does not have one value per output channel.
    """
    if padding > 0:
        # Zero-pad only the two spatial dims: (left, right, top, bottom).
        input = F.pad(input, (padding, padding, padding, padding))
    bs, in_channel, input_h, input_w = input.shape
    out_channel, kernel_in_channel, kernel_h, kernel_w = kernel.shape
    if kernel_in_channel != in_channel:
        raise ValueError(
            f"input has {in_channel} channels but kernel expects {kernel_in_channel}"
        )

    # Normalise bias to a 1-D tensor with one entry per output channel.
    if bias is None:
        bias = torch.zeros(out_channel, dtype=input.dtype, device=input.device)
    else:
        bias = torch.as_tensor(bias, dtype=input.dtype, device=input.device).reshape(-1)
        if bias.numel() == 1:
            # A scalar bias is broadcast to every output channel.
            bias = bias.expand(out_channel)
        elif bias.numel() != out_channel:
            raise ValueError(
                f"bias has {bias.numel()} elements, expected {out_channel}"
            )

    # Output height depends on the height axis, width on the width axis.
    output_h = (input_h - kernel_h) // stride + 1
    output_w = (input_w - kernel_w) // stride + 1
    if output_h <= 0 or output_w <= 0:
        raise ValueError("kernel is larger than the (padded) input")

    output = torch.zeros(
        bs, out_channel, output_h, output_w, dtype=input.dtype, device=input.device
    )
    for ind in range(bs):
        for oc in range(out_channel):
            for oi in range(output_h):  # iterate over output rows (height)
                i = oi * stride
                for oj in range(output_w):  # iterate over output columns (width)
                    j = oj * stride
                    # Window across all input channels: (in_channels, kernel_h, kernel_w).
                    region = input[ind, :, i:i + kernel_h, j:j + kernel_w]
                    output[ind, oc, oi, oj] = torch.sum(region * kernel[oc])
            output[ind, oc] += bias[oc]
    return output
if __name__ == "__main__":
    # Demo: compare the manual loop implementation against PyTorch's F.conv2d.
    # Guarded so that importing this module does not run the comparison, and
    # the sample tensor is not bound to the module-level name `input`
    # (which would shadow the builtin).
    torch.manual_seed(0)  # reproducible random inputs
    x = torch.randn(2, 2, 5, 5)
    kernel = torch.randn(3, 2, 3, 3)
    bias = torch.randn(3)
    pytorch_api_conv2d = F.conv2d(x, kernel, bias=bias, padding=1, stride=2)
    mm_pytorch_full_output = ManualConv2d_MatrixCaculation_batchsize_channel(
        x, kernel, bias=bias, padding=1, stride=2
    )
    flag = torch.allclose(pytorch_api_conv2d, mm_pytorch_full_output)
    print(flag)