.wts
文件是深度学习领域中一种专门用于存储模型权重的文件格式,以下是更详细的介绍:
1. 基本概念
.wts
文件通常保存的是深度学习模型训练完成后的权重参数。这些权重是通过训练过程学习而来的,负责模型的预测和决策。
2. 文件内容结构
.wts
文件的具体内容通常包括:
- 参数名称:每个权重参数的名称,在代码中对应变量名。
- 参数形状:对应参数的大小和维度信息,确保在加载时能够正确重建。
- 权重值:实际的权重数值,通常以浮点数存储,有时会经过特别的编码格式(如十六进制表示)。
3. 文件格式
.wts
文件可以采用多种格式:
- 文本格式:以可读的文本形式存储,每行包含一个参数的名称、大小和权重值。例如:
weight_name shape_size value1 value2 value3 ...
- 二进制格式:为了更高效地存储和读取,部分
.wts
文件采用二进制格式,能够减少文件大小并加快加载速度。
4. 文件特点
.wts
文件是一种用于存储深度学习模型权重的文件格式。其主要用途是将模型的参数(即权重)保存为一种可用于模型推理或部署的格式。以下是 .wts
文件的一些主要特点:
-
包含模型参数:
.wts
文件通常包含模型的所有训练后的参数,例如权重和偏置等,这些参数是模型进行预测的基础。 -
格式简洁:在
.wts
文件中,权重通常以文本格式或二进制格式存储,这样能够减少文件的存储占用,加快加载速度。 -
支持迁移:该文件格式使得用户能够将训练好的模型从一个框架(例如 PyTorch)转换到另一个框架(例如 TensorFlow Lite),便于模型在不同平台上的使用。
-
便于加载和使用:软件工具和框架可以方便地读取
.wts
文件,从而在推理时使用存储的权重,避免重新训练模型。
总结来说,.wts
文件在深度学习中的主要作用是保存和管理模型的权重参数,使得模型的共享和部署变得更加高效。
5. 使用场景
- 模型部署:将训练好的模型导出为
.wts
格式,便于在不同的编程环境或平台进行部署。 - 模型转换:在不同框架之间进行模型转换,如从 PyTorch 转换为 TensorFlow,使用
.wts
文件作为桥梁。 - 移动设备:部分深度学习模型在移动设备上推理时,使用优化后的
.wts
文件进行加载,以节省内存与提高效率。
6. 优点
- 跨平台兼容性:通过
.wts
文件,可以在不同的深度学习框架中共享和使用模型。 - 加载速度快:由于文件结构简单或采用二进制格式,权重加载速度较快,有利于在线推理。
- 资源占用少:相较于原始模型文件,
.wts
文件通常占用更少的存储空间。
7. 工具和框架支持
很多深度学习框架和工具的生态系统中都会包含对 .wts
文件的支持,例如:
- PyTorch:可通过自定义脚本将
.pt
文件转换为.wts
格式。 - TensorFlow、ONNX 和其他框架也可能提供一些工具或转换库来处理
.wts
文件。
8、python实现pt转wts
import sys # noqa: F401
import argparse
import os
import struct
import torch
def parse_args():
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
parser.add_argument('-w', '--weights', required=True,
help='Input weights (.pt) file path (required)')
parser.add_argument(
'-o', '--output', help='Output (.wts) file path (optional)')
parser.add_argument(
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose'],
help='determines the model is detection/classification')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit('Invalid input file')
if not args.output:
args.output = os.path.splitext(args.weights)[0] + '.wts'
elif os.path.isdir(args.output):
args.output = os.path.join(
args.output,
os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
return args.weights, args.output, args.type
pt_file, wts_file, m_type = parse_args()
print(f'Generating .wts for {m_type} model')
# Load model
print(f'Loading {pt_file}')
# Initialize
device = 'cpu'
# Load model
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
if m_type in ['detect', 'seg', 'pose']:
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
delattr(model.model[-1], 'anchors')
model.to(device).eval()
with open(wts_file, 'w') as f:
f.write('{}\n'.format(len(model.state_dict().keys())))
for k, v in model.state_dict().items():
vr = v.reshape(-1).cpu().numpy()
f.write('{} {} '.format(k, len(vr)))
for vv in vr:
f.write(' ')
f.write(struct.pack('>f', float(vv)).hex())
f.write('\n')
这段代码的主要功能是将 PyTorch 模型的参数(权重)从 .pt
文件转换为 .wts
格式文件。接下来我将逐步分解并详细解释代码。
-
定义
parse_args
函数:def parse_args():
该函数用于解析命令行参数。
-
创建参数解析器:
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
使用
argparse
库创建一个解析器,描述该程序的功能为将.pt
文件转换为.wts
文件。 -
添加参数:
-
权重文件参数:
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
该参数是必须的,用于指定输入的权重文件路径。
-
输出文件参数:
parser.add_argument( '-o', '--output', help='Output (.wts) file path (optional)')
该参数是可选的,用于指定输出文件的路径。
-
模型类型参数:
parser.add_argument( '-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg', 'pose'], help='determines the model is detection/classification')
该参数用于指定模型的类型,默认为
detect
,可选的值包括'detect'
、'cls'
、'seg'
和'pose'
。
-
-
解析命令行参数:
args = parser.parse_args()
解析命令行传入的参数。
-
验证输入文件:
if not os.path.isfile(args.weights): raise SystemExit('Invalid input file')
检查输入的权重文件是否存在,如果不存在,则输出错误信息并退出程序。
-
处理输出文件路径:
if not args.output: args.output = os.path.splitext(args.weights)[0] + '.wts' elif os.path.isdir(args.output): args.output = os.path.join( args.output, os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
- 如果没有指定输出文件,则默认将输出文件名设置为与输入文件同名但扩展名为
.wts
。 - 如果指定的输出路径是一个目录,则将输出文件名与输入文件名结合,生成完整的输出路径。
- 如果没有指定输出文件,则默认将输出文件名设置为与输入文件同名但扩展名为
-
返回文件路径和模型类型:
return args.weights, args.output, args.type
-
获取从命令行解析的参数:
pt_file, wts_file, m_type = parse_args()
-
打印模型生成信息:
print(f'Generating .wts for {m_type} model')
-
加载模型:
print(f'Loading {pt_file}') device = 'cpu' model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
从指定的
.pt
文件加载模型,确保将模型加载到 CPU,并转为 FP32(单精度浮点数)。 -
处理模型的锚点:
if m_type in ['detect', 'seg', 'pose']: anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None] delattr(model.model[-1], 'anchors')
如果模型类型是检测、分割或姿态,则计算锚点网格,并从模型中删除锚点属性。
-
切换到评估模式:
model.to(device).eval()
-
写入输出文件:
with open(wts_file, 'w') as f: f.write('{}\n'.format(len(model.state_dict().keys()))) for k, v in model.state_dict().items(): vr = v.reshape(-1).cpu().numpy() f.write('{} {} '.format(k, len(vr))) for vv in vr: f.write(' ') f.write(struct.pack('>f', float(vv)).hex()) f.write('\n')
- 打开输出文件,写入模型参数的数量。
- 遍历模型的所有参数,按格式写入每个参数的名称和大小,以及将每个值转换为十六进制字符串的格式。
这段代码的主要功能是将 PyTorch 模型的参数从 .pt
文件转换为 .wts
文件。通过命令行参数解析,用户可以指定输入权重文件、输出文件和模型类型。代码加载模型后对参数进行处理,并最终以特定格式保存到输出文件中。该程序为模型的迁移和使用提供了便利,特别是在需要将模型权重转换为其他兼容格式时。
导出文件内容如下
9.总结
.wts
文件在深度学习模型的存储、共享和部署中起到了重要的作用。它允许用户将复杂的模型权重信息简化为一种便于管理和迁移的格式,使得模型的使用变得更加高效和灵活。在实际应用中,了解和使用 .wts
文件格式有助于优化模型的开发流程。