1. Definition
1.1 torch.fx is designed to apply arbitrary transformations on a captured graph, so as to implement graph-level functional changes such as graph optimization and quantization.
1.2 It modifies a model without altering the original model.
1.3 It rewrites the computation graph.
2. Implementation
Example: tracing operators into a graph
import torch
from torch.fx import symbolic_trace

class MyModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

module = MyModule()
symbolic_traced = symbolic_trace(module)
print(symbolic_traced.graph)
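Each node of the traced graph can also be inspected programmatically; a minimal sketch using only the public fx.Node attributes (op, target, args):

for node in symbolic_traced.graph.nodes:
    # prints four rows: two placeholders, one call_function (torch.add), one output
    print(node.op, node.target, node.args)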
Operator replacement
import torch
import torch.fx as fx

class MyModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class: type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(m)
    # Retarget every torch.add call to torch.mul; the graph is edited in place.
    for node in graph.nodes:
        if node.op == 'call_function':
            if node.target == torch.add:
                node.target = torch.mul
    graph.lint()  # sanity-check that the edited graph is still well-formed
    return fx.GraphModule(m, graph)

m = MyModule()
newModel = transform(m)
print(newModel)
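A quick sanity check (the input values are illustrative): after the transform, the module multiplies where it used to add.

x, y = torch.tensor(2.0), torch.tensor(3.0)
print(newModel(x, y))  # tensor(6.) from mul; the original module returned tensor(5.)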
Creating a new operator and replacing
import torch
from torch.fx import symbolic_trace
import operator

class M(torch.nn.Module):
    def forward(self, x, y):
        # Three spellings of add: operator.add, torch.add, and Tensor.add
        return x + y, torch.add(x, y), x.add(y)

traced = symbolic_trace(M())
print(traced.graph)

# Every target an "add" can show up as in the traced graph
patterns = set([operator.add, torch.add, "add"])

def add_2_bitwise_and(gm):
    for n in gm.graph.nodes:
        if any(n.target == pattern for pattern in patterns):
            # Insert the new node right after the matched one and
            # redirect all of its users to the new node
            with gm.graph.inserting_after(n):
                new_node = gm.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            gm.graph.erase_node(n)
    gm.recompile()  # regenerate the module's Python code from the edited graph

add_2_bitwise_and(traced)
print(traced.graph)
print(traced.code)
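One caveat worth making explicit: torch.bitwise_and is only defined for integer and boolean tensors, so after the rewrite the module must be fed integer inputs. A small check (values illustrative):

a = torch.tensor([5])  # 0b101
b = torch.tensor([3])  # 0b011
print(traced(a, b))    # all three former adds now compute 5 & 3 = 1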
Subgraph replacement
import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        val = torch.neg(x) + torch.relu(x)
        return torch.add(val, val)

traced = symbolic_trace(M())
print(traced.graph)
print(traced.code)

# Pattern and replacement are plain functions; the rewriter traces them too,
# then swaps every occurrence of the pattern subgraph for the replacement.
def pattern(x):
    return torch.neg(x) + torch.relu(x)

def replacement(x):
    # neg(x) + relu(x) is -x for x <= 0 and 0 for x > 0,
    # i.e. exactly neg(clamp(x, max=0))
    return torch.neg(torch.clamp(x, max=0))

torch.fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)
print(traced.graph)
print(traced.code)

# Numerical check against a hand-written reference
def comparison(x):
    val = torch.neg(torch.clamp(x, max=0))
    return torch.add(val, val)

comparison_fn = symbolic_trace(comparison)
x = torch.rand(1, 3)
ref_output = comparison_fn(x)
test_output = traced.forward(x)
print(torch.max(ref_output - test_output))  # should print 0
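replace_pattern also returns the list of matches it rewrote, which is handy for asserting that a rewrite actually fired (the exact element type of the returned list varies across PyTorch versions). A small sketch on a freshly traced copy:

fresh = symbolic_trace(M())
matches = torch.fx.subgraph_rewriter.replace_pattern(fresh, pattern, replacement)
print(len(matches))  # 1: the neg+relu pattern occurs exactly once in M.forward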
FX model quantization

The core post-training static quantization flow is: select the quantized backend, build its default QConfigMapping, insert observers with prepare_fx, run calibration batches through the prepared model, and convert with convert_fx:
torch.backends.quantized.engine = 'fbgemm'
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
model_to_quantize = copy.deepcopy(model)
prepared_model = prepare_fx(model_to_quantize, qconfig_mapping,
                            example_inputs=(torch.randn(1, 3, 224, 224),))
prepared_model.eval()
with torch.inference_mode():
    for inputs, labels in test_dataloader:
        prepared_model(inputs)  # calibration: observers record activation statistics
quantized_recover_model = convert_fx(prepared_model)
The complete runnable script:

import os
import copy
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.models.resnet import resnet18
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mapping
from torch.utils.data import DataLoader

def test(model, test_dataloader, device):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    test_acc = correct / total
    print('[INFO] Test Accuracy: {:.3f}'.format(test_acc), '\n')

def print_size_of_model(model):
    torch.save(model.state_dict(), "tmp.pt")
    print(f"The model size: {os.path.getsize('tmp.pt') / 1e6} MB")

# Adapt ResNet-18 to CIFAR-10: 3x3 stem, no max-pool, 10-way classifier
model = resnet18(pretrained=True)
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()
model.fc = nn.Linear(model.fc.in_features, 10)
model.eval()

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_data = torchvision.datasets.CIFAR10(root='data', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10(root='data', train=False, transform=transform_test, download=True)
print("Training set size: {}".format(len(train_data)))
print("Test set size: {}".format(len(test_data)))
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# Post-training static quantization with FX
torch.backends.quantized.engine = 'fbgemm'
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
model_to_quantize = copy.deepcopy(model)
# example_inputs only guides tracing; note CIFAR-10 images are 32x32
prepared_model = prepare_fx(model_to_quantize, qconfig_mapping,
                            example_inputs=(torch.randn(1, 3, 224, 224),))
prepared_model.eval()
with torch.inference_mode():
    for inputs, labels in test_dataloader:
        prepared_model(inputs)  # calibration pass over the test set
quantized_recover_model = convert_fx(prepared_model)

# Export the quantized model via TorchScript
script_module = torch.jit.trace(quantized_recover_model, example_inputs=torch.randn([1, 3, 224, 224]))
torch.jit.save(script_module, "quant_model.pth")

# Profile the float model ...
with torch.autograd.profiler.profile(enabled=True, use_cuda=False, record_shapes=False, profile_memory=False) as prof:
    test(model, test_dataloader, device='cpu')
print(prof.table())

# ... and the quantized one
quantized_recover_model = torch.jit.load("quant_model.pth")
with torch.autograd.profiler.profile(enabled=True, use_cuda=False, record_shapes=False, profile_memory=False) as prof:
    test(quantized_recover_model, test_dataloader, device='cpu')
print(prof.table())
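print_size_of_model is defined above but never called; a natural closing check (a sketch, and the exact figures depend on the model and torch version) compares on-disk sizes, where int8 weights typically come in near a quarter of the float32 size:

print_size_of_model(model)                    # float32 ResNet-18
print_size_of_model(quantized_recover_model)  # int8 weights, roughly 4x smaller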
API notes: the FX quantization interfaces and their parameter configuration