1. Download the project
git lfs install
git clone https://www.modelscope.cn/damo/cv_tinynas_object-detection_damoyolo_safety-helmet.git
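The `git lfs install` step matters. If Git LFS is not set up, the clone can contain small pointer stubs in place of the large files. The stdlib-only Python sketch below walks the cloned directory and lists any files that are still pointer stubs. It relies on the fact that LFS pointer files are under 1024 bytes and begin with the LFS spec version line. The helper name `find_lfs_pointers` is hypothetical and not part of ModelScope.

```python
import os

# First line of every Git LFS pointer file (LFS spec v1).
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def find_lfs_pointers(repo_dir):
    """Return sorted paths of files under repo_dir that are still Git LFS
    pointer stubs instead of real content (hypothetical helper)."""
    pointers = []
    for root, dirs, files in os.walk(repo_dir):
        # Do not descend into Git's own metadata directory.
        dirs[:] = [d for d in dirs if d != ".git"]
        for name in files:
            path = os.path.join(root, name)
            # Pointer files are always smaller than 1024 bytes.
            if os.path.getsize(path) >= 1024:
                continue
            with open(path, "rb") as f:
                if f.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX:
                    pointers.append(path)
    return sorted(pointers)
```

Usage: `find_lfs_pointers('cv_tinynas_object-detection_damoyolo_safety-helmet')` should return an empty list after a successful clone. If it lists files, run `git lfs pull` inside the repository.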
2. Test the model
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Local path to the repository cloned in step 1
model_id = '/workspace_wjr/develop/projects/cv_tinynas_object-detection_damoyolo_safety-helmet'
# Sample test image hosted by ModelScope
input_location = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_safetyhat.jpg'
# Build the detection pipeline from the local model and run it on the image
safety_hat_detection = pipeline(Tasks.domain_specific_object_detection, model=model_id)
result = safety_hat_detection(input_location)
print("result is:", result)
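The printed result can be post-processed further. This is a minimal sketch, assuming (not confirmed by these notes) that the result is a dict with parallel `scores`, `labels` and `boxes` entries, as ModelScope detection pipelines commonly return. Print `result` first to check the actual keys. The helper `filter_detections` and the 0.5 threshold are illustrative choices.

```python
def filter_detections(result, score_thr=0.5):
    """Keep detections with score >= score_thr, highest score first.

    Assumes result["scores"], result["labels"] and result["boxes"] are
    parallel sequences (boxes as 4 numbers each); verify against the output.
    """
    kept = []
    for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
        if score >= score_thr:
            kept.append({
                "label": label,
                "score": float(score),
                "box": [float(v) for v in box],  # works for lists or numpy rows
            })
    # Sort so the most confident detections come first.
    kept.sort(key=lambda d: d["score"], reverse=True)
    return kept
```

Usage: `for det in filter_detections(result, 0.5): print(det)`.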
3. Convert the model to ONNX
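from modelscope.models import Model
import torch

# Local path to the repository cloned in step 1
model_id = '/workspace_wjr/develop/projects/cv_tinynas_object-detection_damoyolo_safety-helmet'
model = Model.from_pretrained(model_id)
model.eval()  # required, see Note 1
# Dummy input with the export shape: batch 1, 3 channels, 640x640
fake_input = torch.randn((1, 3, 640, 640)).float()
# Optional sanity checks before exporting:
# y = model.forward(fake_input)
# y = model(fake_input)
# Make calling the model run forward() only, skipping postprocess (see Note 2)
type(model).__call__ = type(model).forward
torch.onnx.export(model, fake_input, 'helmet.onnx', opset_version=13)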
Note:
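1. Call eval(). By default the model is in training mode, and in that mode forward() simply passes without doing anything.
2. Override the default __call__. The model's __call__ runs forward() followed by postprocess(), which causes an error during ONNX export; pointing __call__ at forward() makes the export trace only the network itself.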
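Both notes can be illustrated without torch or ModelScope. `ToyModel` below is a hypothetical stand-in, not the real DAMO-YOLO class. It only mimics the behaviour described above: forward() does nothing in training mode, and __call__ chains forward() with postprocess(). It also shows why the script patches `type(model)` rather than the instance.

```python
class ToyModel:
    """Hypothetical stand-in mimicking the two behaviours in the notes."""

    def __init__(self):
        self.training = True  # starts in training mode

    def eval(self):
        self.training = False
        return self

    def forward(self, x):
        if self.training:
            return None  # Note 1: forward does nothing in training mode
        return [v * 2 for v in x]  # placeholder "network output"

    def postprocess(self, y):
        # Stands in for post-processing that ONNX export cannot handle.
        return {"boxes": y}

    def __call__(self, x):
        # Note 2: the default __call__ chains forward() and postprocess().
        return self.postprocess(self.forward(x))


model = ToyModel().eval()
print(model([1, 2]))  # {'boxes': [2, 4]}

# Python looks up special methods such as __call__ on the type, not the
# instance, so assigning to the instance does not change model(...).
model.__call__ = model.forward
print(model([1, 2]))  # still {'boxes': [2, 4]}

# Patching the class works, and affects every instance of it.
type(model).__call__ = type(model).forward
print(model([1, 2]))  # [2, 4]
```

Because the patch is applied to the class, every instance of that model class keeps the forward-only __call__ for the rest of the Python process, so running the export in its own script keeps it from interfering with normal pipeline use.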