# PyTorch Lightning 最简单的训练、推理、导出 ONNX 示例
# (Minimal PyTorch Lightning example: training, inference and ONNX export)
# 作者 / Author: flyfish
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
class SimpleModel(pl.LightningModule):
    """Minimal LightningModule: one linear layer trained with MSE loss.

    Args:
        input_dim: Number of input features of the linear layer.
        output_dim: Number of output features of the linear layer.
        lr: Learning rate for the Adam optimizer. Defaults to 0.001, the
            value that was previously hard-coded, so existing callers and
            checkpoints saved without ``lr`` keep the same behaviour.
    """

    def __init__(self, input_dim, output_dim, lr=1e-3):
        super().__init__()
        # Stores input_dim, output_dim and lr in self.hparams (and in saved
        # checkpoints), which lets load_from_checkpoint rebuild the module.
        self.save_hyperparameters()
        self.layer = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` (its last dimension must equal input_dim)."""
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the MSE loss for one ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        # Without an explicit log call the loss is never recorded by the
        # logger; prog_bar=True also shows it in the progress bar.
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``hparams.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
# Create synthetic example data: 100 random samples with input_dim features
# and output_dim regression targets. Seeding torch's global RNG first makes
# the data, the model's later weight initialisation and the shuffle order
# reproducible across runs.
SEED = 42
torch.manual_seed(SEED)
input_dim = 10
output_dim = 1
X = torch.randn(100, input_dim)
y = torch.randn(100, output_dim)

# Wrap the tensors in a DataLoader. shuffle=True so each epoch visits the
# samples in a fresh order, the usual setting for a training loader.
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# File the trained weights are written to once fitting finishes.
checkpoint_path = 'model_checkpoint.ckpt'

# Build the LightningModule using the dimensions defined above.
model = SimpleModel(input_dim=input_dim, output_dim=output_dim)

# A Trainer capped at 10 epochs; fit() runs the training loop over the
# loader, after which the current state is saved to checkpoint_path.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, dataloader)
trainer.save_checkpoint(checkpoint_path)
# Reload the trained module from the checkpoint.
# map_location="cpu" forces the weights onto the CPU. Depending on the
# Lightning version, the module may otherwise be restored onto the device
# it was trained on (e.g. a GPU). example_input below is a CPU tensor, so
# that would break the ONNX export with a device mismatch.
# input_dim/output_dim are already stored by save_hyperparameters(); passing
# them explicitly overrides the stored values (here with the same numbers).
model = SimpleModel.load_from_checkpoint(
    checkpoint_path,
    map_location="cpu",
    input_dim=input_dim,
    output_dim=output_dim,
)
# Put the module in evaluation mode before tracing it for export.
model.eval()

# Dummy input used to trace the graph. A batch size of 1 is sufficient
# because the batch axis is declared dynamic below.
example_input = torch.randn(1, input_dim)

# Export to ONNX.
onnx_path = 'model.onnx'
torch.onnx.export(
    model,                      # module to export
    example_input,              # example input used for tracing
    onnx_path,                  # output file path
    input_names=["input"],      # name of the graph input
    output_names=["output"],    # name of the graph output
    dynamic_axes={              # let dimension 0 (batch) vary at runtime
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    opset_version=11,           # ONNX opset version
)
print(f"Model exported to {onnx_path}")