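When migrating a model, you may find that the source code uses the mixed-precision module that ships with PyTorch 1.6 (`torch.cuda.amp`), while the PyTorch version you are running (1.5) does not provide it, so the program fails to run. In that case, replace `import torch.cuda.amp as amp` in the source with `from apex import amp` and adapt the corresponding code; after that the program runs normally (see https://nvidia.github.io/apex/amp.html). The example from that page: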
```python
import torch
from apex import amp

# Declare model and optimizer as usual, with default (FP32) precision
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Allow Amp to perform casts as required by the opt_level
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
...
# loss.backward() becomes:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```
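Since the whole problem comes from `torch.cuda.amp` only existing from PyTorch 1.6 onward, a script that has to run on both versions could pick the backend from the installed version string. The following is a minimal sketch under that assumption; `parse_torch_version`, `choose_amp_backend` and `NATIVE_AMP_MIN_VERSION` are hypothetical helper names, not part of PyTorch or apex.

```python
import re

# Assumption: torch.cuda.amp (autocast + GradScaler) is available from PyTorch 1.6 on.
NATIVE_AMP_MIN_VERSION = (1, 6)


def parse_torch_version(version: str) -> tuple:
    """Extract (major, minor) from strings such as '1.5.0', '1.6.0+cu101' or '2.1.0a0'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def choose_amp_backend(version: str) -> str:
    """Return 'native' when torch.cuda.amp should exist, otherwise 'apex'."""
    if parse_torch_version(version) >= NATIVE_AMP_MIN_VERSION:
        return "native"
    return "apex"
```

In a training script this would be called as `choose_amp_backend(torch.__version__)`, and the result used to decide between the `torch.cuda.amp` code path and the `from apex import amp` code path. Comparing `(major, minor)` tuples rather than raw strings keeps versions like `1.10` ordered correctly after `1.6`.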
Here is how part of a real training loop was changed:
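```python
# Perform the forward pass and compute the loss
# Original torch.cuda.amp version (PyTorch >= 1.6):
# with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
#     preds = model(inputs)
#     loss = loss_fun(preds, labels_one_hot)
# apex version: after amp.initialize(...), opt_level "O1" casts ops
# automatically, so no autocast context manager is needed
preds = model(inputs)
loss = loss_fun(preds, labels_one_hot)
# Perform the backward pass and update the parameters
optimizer.zero_grad()
# Original torch.cuda.amp version:
# scaler.scale(loss).backward()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```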
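Both the original `scaler.scale(loss).backward()` and the apex `amp.scale_loss(...)` block rely on dynamic loss scaling: the loss is multiplied by a large factor so small FP16 gradients do not underflow, gradients are divided back before the update, and a step whose gradients contain inf/NaN is skipped while the scale is reduced. The sketch below illustrates that idea in plain Python, with scalar floats standing in for gradient tensors. `DynamicLossScaler` and `training_step` are hypothetical names; this is a simplified model of the mechanism, not the actual apex or PyTorch implementation. The default values are assumed to follow the documented defaults of `torch.cuda.amp.GradScaler`.

```python
import math


class DynamicLossScaler:
    """Hypothetical, simplified dynamic loss scaler (not the apex/PyTorch code)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without overflow

    def scale_loss(self, loss):
        # Enlarge the loss so that small gradients survive in FP16.
        return loss * self.scale

    def unscale_and_check(self, scaled_grads):
        # Divide the gradients back; any inf/NaN means the scale was too large.
        grads = [g / self.scale for g in scaled_grads]
        overflow = any(not math.isfinite(g) for g in grads)
        return grads, overflow

    def update(self, overflow):
        # Overflow: shrink the scale. After growth_interval clean steps: grow it.
        if overflow:
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0


def training_step(params, scaled_grads, lr, scaler):
    """Plain SGD update that is skipped entirely when the gradients overflowed."""
    grads, overflow = scaler.unscale_and_check(scaled_grads)
    if not overflow:
        params = [p - lr * g for p, g in zip(params, grads)]
    scaler.update(overflow)
    return params, overflow
```

In the real libraries the same skip-and-rescale logic runs behind `scaler.step(optimizer)` / `scaler.update()` in `torch.cuda.amp`, while with apex the training loop keeps calling an ordinary `optimizer.step()` after the `amp.scale_loss` block.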