目录
Python 导出 ONNX:
# Export the model to an ONNX file.
weights = 'xxxx.pth'  # placeholder checkpoint path
model_cls = Model_net(model_path=weights, num_classes=2)
# map_location="cpu" lets a checkpoint that was saved from a CUDA device be
# deserialized on a machine without a GPU. load_state_dict copies the values
# into the model's existing parameters, so the model's own device placement
# is not changed by this.
state_dict = torch.load(weights, map_location="cpu")
model_cls.load_state_dict(state_dict)
model_cls.eval()
# NOTE(review): torch.onnx.export expects a torch.nn.Module - confirm that
# Model_net itself provides load_state_dict()/eval()/forward().
# Example input of shape (1, 3, 64, 64) handed to the exporter.
dummy_input1 = torch.randn(1, 3, 64, 64)
onnx_name = "smoke_skipnet.onnx"
torch.onnx.export(
    model_cls,
    dummy_input1,
    onnx_name,
    opset_version=11,
    verbose=False,
    input_names=['images'],
    output_names=['output'],
)
Python 预测例子:
class Model_net():
def __init__(self,model_path,num_classes=6):
self.model = skipnet(num_classes=num_classes)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.num_classes=num_cla