在模型优化的过程中,算子融合是必不可少的一步。最近我在一个项目中使用了 ResNet,其中包含大量的 BN(BatchNorm)层,因此必须对其进行优化;将卷积与 BN 融合之后,推理性能提高了百分之四十。
这部分代码是我参考网上的资料自己写的。在查资料的过程中,我偶然发现 PyTorch 1.10 中的 torch.fx 模块可以很好地处理这类算子融合问题,后续我会对这部分内容进行补充。
话不多说直接上代码:
import os
import torch
import torch.nn as nn
from build_net import make_fuse_model
from utils import DummyModule
def fuse(conv, bn):
    """Fold a BatchNorm2d layer into the Conv2d layer that feeds it.

    With the batch-norm running statistics treated as constants (inference
    mode), ``bn(conv(x))`` is itself a convolution whose parameters are::

        scale  = bn.weight / sqrt(bn.running_var + bn.eps)
        weight = conv.weight * scale          # per output channel
        bias   = (conv.bias - bn.running_mean) * scale + bn.bias

    Args:
        conv: The ``nn.Conv2d`` whose output is normalized by ``bn``.
        bn: The ``nn.BatchNorm2d`` applied directly to the output of ``conv``.

    Returns:
        A new ``nn.Conv2d`` (always created with a bias) that has the same
        geometry as ``conv`` and carries the folded weight and bias.

    Raises:
        ValueError: If ``bn`` has no running statistics, or if its number of
            features does not match ``conv.out_channels``.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError(
            "cannot fuse a BatchNorm2d without running statistics "
            "(track_running_stats=False)"
        )
    if bn.num_features != conv.out_channels:
        raise ValueError(
            "channel mismatch: conv.out_channels={} but bn.num_features={}".format(
                conv.out_channels, bn.num_features
            )
        )

    # Folding is pure parameter arithmetic; no autograd graph is needed.
    with torch.no_grad():
        mean = bn.running_mean
        var_sqrt = torch.sqrt(bn.running_var + bn.eps)
        # With affine=False the layer has no learnable scale/shift, which is
        # equivalent to a scale of one and a shift of zero.
        gamma = bn.weight if bn.weight is not None else torch.ones_like(mean)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(mean)
        conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(mean)

        # One scale factor per output channel, broadcast over (in, kH, kW).
        fused_weight = conv.weight * (gamma / var_sqrt).reshape(-1, 1, 1, 1)
        fused_bias = (conv_bias - mean) / var_sqrt * gamma + beta

    fused_conv = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
        bias=True,
        padding_mode=conv.padding_mode,
    )
    fused_conv.weight = nn.Parameter(fused_weight)
    fused_conv.bias = nn.Parameter(fused_bias)
    # Keep the train/eval flag consistent with the layer being replaced.
    fused_conv.train(conv.training)
    return fused_conv
def fuse_module(m):
    """Recursively fold Conv2d + BatchNorm2d pairs inside ``m``, in place.

    Direct children are visited in registration order. When a BatchNorm2d is
    registered immediately after a Conv2d, the conv entry is replaced by the
    result of ``fuse(conv, bn)`` and the batch-norm entry is replaced by a
    ``DummyModule()`` (imported from ``utils``; presumably an identity
    pass-through -- confirm against its definition). Every other child is
    descended into recursively.

    Registration order is used as a stand-in for execution order; this
    function cannot see how ``forward`` actually wires the children together.

    Args:
        m: The ``nn.Module`` to rewrite. Its ``_modules`` mapping is mutated.
    """
    pending_conv = None
    pending_name = None
    # Iterate over a snapshot because ``m._modules`` is rewritten in the loop.
    for name, child in list(m.named_children()):
        if isinstance(child, nn.BatchNorm2d) and pending_conv is not None:
            m._modules[pending_name] = fuse(pending_conv, child)
            m._modules[name] = DummyModule()
            pending_conv = None
        elif isinstance(child, nn.Conv2d):
            pending_conv = child
            pending_name = name
        else:
            # Any other layer between a conv and a later batch norm (e.g. an
            # activation) breaks the conv -> bn pattern, so the remembered
            # conv must not be folded into a batch norm that follows it.
            pending_conv = None
            fuse_module(child)
def validate(net, input_, fuse_model_path):
    """Fuse ``net`` in place, save it, and report the output drift.

    The network is switched to eval mode, run once on ``input_``, rewritten
    by ``fuse_module``, serialized with ``torch.save`` and run again on the
    same input.

    Args:
        net: The model to fuse; it is modified in place.
        input_: Input passed to ``net`` before and after fusion.
        fuse_model_path: Destination path for the saved fused model (the
            whole module object is saved, not just its state dict).

    Returns:
        The largest absolute element-wise difference between the outputs
        before and after fusion, as a Python float. Assumes ``net`` returns a
        single tensor -- confirm for models with other output types.
    """
    net.eval()
    # Inference only: skip autograd bookkeeping for both forward passes.
    with torch.no_grad():
        reference = net(input_)
    fuse_module(net)
    print(fuse_model_path)
    torch.save(net, fuse_model_path)
    with torch.no_grad():
        fused_output = net(input_)
    return (reference - fused_output).abs().max().item()
if __name__ == '__main__':
    # Name shared by the source checkpoint and the fused output file.
    origin_model_type = 'simplate'
    model_path = './model/' + origin_model_type + '.pth'
    fuse_model_path = './fuse_model/' + origin_model_type + '.pth'
    os.makedirs('./fuse_model/', exist_ok=True)

    model = make_fuse_model()
    # The script runs on CPU tensors only, so map the checkpoint to CPU; this
    # keeps a checkpoint saved from a GPU loadable on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

    # Prints the save path (inside validate) and then the maximum absolute
    # output difference between the original and the fused model.
    print(validate(model, torch.randn(64, 3, 80, 80), fuse_model_path))