主要思路：1×1 卷积等价于一次矩阵乘法——将输入的 ci 维与权重的 ci 维相乘并求和（即在 ci 维上做内积），然后在 co 维上加上 bias。
代码如下：
import torch
def test_conv_to_matmul(n, h, w, ci, co, device=None):
    """Check that a 1x1 ``Conv2d`` equals a matrix multiplication plus bias.

    A 1x1 convolution mixes channels independently at every spatial
    position. Flattening the input to ``(n*h*w, ci)`` rows, multiplying by
    the transposed ``(co, ci)`` weight matrix and adding the bias along the
    ``co`` axis therefore reproduces the convolution output.

    Args:
        n: batch size of the random input.
        h: input height.
        w: input width.
        ci: number of input channels.
        co: number of output channels.
        device: device to run on. ``None`` (the default) selects ``"cuda"``
            when available and falls back to ``"cpu"`` otherwise.

    Returns:
        The maximum absolute element-wise difference between the conv
        output and the matmul result, as a Python float (it is also
        printed).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    conv = torch.nn.Conv2d(ci, co, (1, 1)).to(device)
    x = torch.randn(n, ci, h, w, device=device)

    with torch.no_grad():  # pure numeric check; no autograd graph needed
        expected = conv(x)  # (n, co, h, w)

        # NCHW -> NHWC -> (n*h*w, ci): one row per spatial position.
        # Explicit row count instead of -1 so empty inputs don't make
        # reshape ambiguous.
        rows = x.permute(0, 2, 3, 1).reshape(n * h * w, ci)
        # Conv weight is (co, ci, 1, 1); drop the 1x1 kernel dims.
        weight = conv.weight.reshape(co, ci)
        # (n*h*w, ci) @ (ci, co) -> (n*h*w, co); bias broadcasts over co.
        result = rows @ weight.t() + conv.bias
        # Back to NCHW so it lines up with the conv output.
        result = result.reshape(n, h, w, co).permute(0, 3, 1, 2)

    # Use max |diff| rather than a signed sum: positive and negative
    # errors in a sum can cancel to ~0 even when the outputs disagree.
    diff = (expected - result).abs()
    max_err = diff.max().item() if diff.numel() else 0.0
    print(max_err, flush=True)
    return max_err


# Named test_* for historical reasons but takes explicit sizes, so tell
# pytest not to collect it as a fixture-driven test.
test_conv_to_matmul.__test__ = False
if __name__ == "__main__":
    # Demo run: batch of 10, 20x30 spatial size, 6 -> 8 channels.
    test_conv_to_matmul(n=10, h=20, w=30, ci=6, co=8)