def main():
    """Export a serialized RetinaFace PyTorch model to a .wts text file
    (the TensorRT-x weight format).

    The file starts with a header line holding the number of state-dict
    entries, followed by one line per tensor:
    "<key> <count> <hex-float> <hex-float> ...".
    """
    print('cuda device count: ', torch.cuda.device_count())
    device = 'cuda:0'
    # NOTE(review): torch.load of a whole module object; assumes the model
    # class is importable at load time — confirm against the training repo.
    net = torch.load('retinaface.pth')
    net = net.to(device)
    net.eval()
    print('model: ', net)
    # print('state dict: ', net.state_dict().keys())

    # Sanity-check a forward pass with a dummy 1x3x384x640 input.
    tmp = torch.ones(1, 3, 384, 640).to(device)
    print('input: ', tmp)
    out = net(tmp)
    print('output:', out)

    # Skip the export if the weights file already exists.
    if os.path.exists('retinaface.wts'):
        return

    # Context manager guarantees the file is closed even on error
    # (the original left the handle open).
    with open("retinaface.wts", 'w') as f:
        f.write("{}\n".format(len(net.state_dict().keys())))
        for k, v in net.state_dict().items():
            print('key: ', k)
            print('value: ', v.shape)
            # Flatten to 1-D on CPU; each value is written as the hex of
            # its big-endian IEEE-754 float32 encoding.
            vr = v.reshape(-1).cpu().numpy()
            f.write("{} {}".format(k, len(vr)))
            for vv in vr:
                f.write(" ")
                f.write(struct.pack(">f", float(vv)).hex())
            f.write("\n")
1.2 wts文件转为engine文件
1.2.1 输出流文件.engine 将wts内容序列化后写入(核心函数APIToModel)
IHostMemory* modelStream{nullptr};APIToModel(BATCH_SIZE,&modelStream); ★★★★★
assert(modelStream !=nullptr);
std::ofstream p("arcface-mobilefacenet.engine", std::ios::binary);if(!p){
std::cerr <<"could not open plan output file"<< std::endl;return-1;}
p.write(reinterpret_cast<constchar*>(modelStream->data()), modelStream->size());
modelStream->destroy();
1.2.2 APIToModel(核心函数createEngine)
voidAPIToModel(unsignedint maxBatchSize, IHostMemory** modelStream){// Create builder
IBuilder* builder =createInferBuilder(gLogger);
IBuilderConfig* config = builder->createBuilderConfig();// Create model to populate the network, then set the outputs and create an engine
ICudaEngine* engine =createEngine(maxBatchSize, builder, config, DataType::kFLOAT);★★★★★
assert(engine !=nullptr);// Serialize the engine(*modelStream)= engine->serialize();// Close everything down
engine->destroy();
builder->destroy();}