模型fp16保存,模型训练需要amp混合精度训练。
import torch
import torchvision
import model
device = torch.device('cuda')
scale_factor = 8
model = model.enhance_net_nopool(scale_factor)
checkpoint = torch.load(r"E:\NEt\Zero-DCE_++\Zero-DCE++\snapshots_Zero_DCE++\Epoch5.pth", map_location=device)
model.load_state_dict(checkpoint)
model.eval()
model.cuda().half()
img = torch.randn(1, 3, 3000, 4096, requires_grad=False,dtype=torch.float16).cuda()
# 创建 enhance 方法的封装函数
def enhance_image(input):
with torch.no_grad():
input = input.detach() # 设置输入不需要梯度
for param in model.parameters():
param.requires_grad = False # 设置模型参数不需要梯度
enhance_output, _ = model(input)
return enhance_output
traced_script_module = torch.jit.trace(enhance_image, img)#
output = traced_script_module(img)
print(output)
# 保存转换后的模型为 .pt
traced_script_module.save('./zerodec++_scale8.torchscript')#保存路径
pytorch模型训练之fp16、apm、多GPU模型、梯度检查点(gradient checkpointing)显存优化等 - 知乎 (zhihu.com)
1.全精度训练
import torch
model = torch.nn.Linear(D_in, D_out)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
out = model(img)
loss = LOSS(out, label)
loss.backward()
optimizer.step()
optimizer.zero_grad()
2.半精度训练,模型体积减半,loss可能为nan
将,模型和输入数据同时half()
import torch
model = torch.nn.Linear(D_in, D_out).half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
out = model(img.half())
loss = LOSS(out, label)
loss.backward()
optimizer.step()
optimizer.zero_grad()
3.混合精度训练
1)输入数据不需要再转换为半精度
from torch.cuda.amp import GradScaler as GradScaler
from torch.cuda.amp import autocast # 混合精度
from apex import amp
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
# GradScaler对象用来自动做梯度缩放
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# 在autocast enable 区域运行forward
with autocast():
# model做一个FP16的副本,forward
output = model(input)
loss = loss_fn(output, target)
# 用scaler,scale loss(FP16),backward得到scaled的梯度(FP16)
scaler.scale(loss).backward()
# scaler 更新参数,会先自动unscale梯度
# 如果有nan或inf,自动跳过
scaler.step(optimizer)
# scaler factor更新
scaler.update()
opt_level="O2"
opt_level:00相当于原始的单精度训练。01在大部分计算时采用半精度,但是所有的模型参数依然保持单精度,对于少数单精度较好的计算(如softmax)依然保持单精度。02相比于01,将模型参数也变为半精度。03基本等于最开始实验的全半精度的运算。值得一提的是,不论在优化过程中,模型是否采用半精度,保存下来的模型均为单精度模型,能够保证模型在其他应用中的正常使用。