使用的模型是结构比较简单的 UNet，没有用到 PyTorch 中那些复杂的函数。
首先，要把 PyTorch 保存的 .pth 模型（state_dict）转换成 ONNX 模型：
def save_onnx_illu(onnx_path='static_model/models_illu.onnx',
                   input_channels=6, input_size=(1024, 1024)):
    """Export the trained illumination model (``KD_teacher``) to a static ONNX file.

    The experiment name is taken from ``parse_args().testname``. The config at
    ``models_illu/<testname>/config.yml`` supplies the ``testname`` under which
    the weights ``model.pth`` (a state_dict) are stored.

    Args:
        onnx_path: Destination of the exported ONNX graph. Its parent
            directory is created if it does not exist.
        input_channels: Number of input channels the network is built with;
            also used for the dummy tracing input.
        input_size: ``(height, width)`` of the dummy tracing input. No dynamic
            axes are declared, so the exported graph is fixed to this shape.

    The model and the dummy input are moved to the GPU with ``.cuda()``, so a
    CUDA device is required.
    """
    import os

    input_names = ['input']
    output_names = ['output']

    # Build the network (a single model is loaded here).
    model = models.__dict__["KD_teacher"](input_channels=input_channels).cuda()

    args = parse_args()
    # NOTE(review): FullLoader is kept for compatibility with existing config
    # files; prefer yaml.safe_load if configs only ever hold plain values.
    with open('models_illu/%s/config.yml' % args.testname, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    weights_path = 'models_illu/%s/model.pth' % config['testname']
    print(weights_path)
    # map_location='cuda' keeps loading working when the checkpoint was saved
    # on a GPU index that does not exist on this machine.
    model.load_state_dict(torch.load(weights_path, map_location='cuda'))
    model.eval()

    # The dummy input only fixes the traced shapes; its values are irrelevant.
    # A plain tensor is used because torch.autograd.Variable is a deprecated
    # no-op wrapper.
    height, width = input_size
    dummy_input = torch.randn(1, input_channels, height, width).cuda()

    # torch.onnx.export does not create missing directories.
    out_dir = os.path.dirname(onnx_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Tracing needs no autograd graph; disabling it saves memory on large inputs.
    with torch.no_grad():
        torch.onnx.export(model, dummy_input, onnx_path,
                          input_names=input_names,
                          output_names=output_names,
                          verbose=False)