import torch
import struct
from utils.torch_utils import select_device
def pt2weight(weights):
    """Dump every tensor of a checkpoint's model to a text ``.weights`` file.

    The checkpoint is loaded onto the CPU and its ``'model'`` entry is put in
    eval mode and cast to float.  Each ``state_dict`` entry becomes one line:
    the entry name followed by every element of the flattened tensor, each
    written as a space-separated big-endian float32 hex string.  The output
    file is written next to *weights* with its extension replaced by
    ``.weights``.

    Args:
        weights: Path (str or path-like) to a checkpoint whose ``'model'``
            entry is a pickled ``torch.nn.Module``.

    Returns:
        The last ``state_dict`` tensor whose name ends in ``.anchors`` or
        ``.anchor_grid``.

    Raises:
        ValueError: if no such anchor tensor exists (the output file has
            already been written at that point).

    Security: loading a full-module checkpoint unpickles arbitrary Python
    objects; only call this on checkpoints from a trusted source.
    """
    import inspect
    import os

    # torch>=2.6 defaults to weights_only=True, which refuses to unpickle the
    # whole nn.Module stored under ckpt['model']; very old torch versions do
    # not have the parameter at all, so only pass it when it exists.
    load_kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(weights, **load_kwargs)['model'].eval().float()

    # splitext strips only the final extension, so dots in directory names
    # (e.g. "./runs/exp.1/yolov5s.pt") no longer truncate the output path.
    out_path = os.path.splitext(weights)[0] + ".weights"

    anchor_grid = None
    with open(out_path, 'w') as f:
        for k, v in model.state_dict().items():
            # Match on the last name component so anchor buffers are found
            # whatever index the module holding them sits at.
            if k.rsplit(".", 1)[-1] in ("anchors", "anchor_grid"):
                anchor_grid = v
            flat = v.reshape(-1).tolist()
            hex_words = [struct.pack('>f', float(x)).hex() for x in flat]
            # One write per tensor: "<name> <hex> <hex> ...\n".
            f.write(" ".join([k, *hex_words]) + "\n")

    if anchor_grid is None:
        raise ValueError(
            "no '*.anchors' or '*.anchor_grid' tensor found in {}".format(weights)
        )
    return anchor_grid
if __name__ == '__main__':
    # Export yolov5s.pt to yolov5s.weights and print the anchor tensor
    # returned by pt2weight, first as a tensor and then as a flat list.
    anchor_grid = pt2weight('yolov5s.pt')
    print(anchor_grid)
    # squeeze() only drops size-1 dimensions, so flattening directly gives the
    # same element order as the old squeeze chain without assuming the tensor
    # has at least three dimensions.
    new_anchor = anchor_grid.reshape(-1)
    data = new_anchor.tolist()
    print(data)