输入类型是CPU,参数类型是GPU
原代码
x = torch.randn(1, 3, 608, 608)
script_models = torch.jit.trace(model, x)
script_models.save("m.jit")`
是一个使用torch.jit导出模型结构的
改为
x = torch.randn(1, 3, 608, 608)
x = x.to(device)
script_models = torch.jit.trace(model, x)
script_models.save("m.jit")
把x转换为GPU类型