# 半精度训练 — half-precision (AMP) training
import sys

import torch
from torch.cuda.amp import autocast, GradScaler  # AMP helpers for mixed-precision training
def trian_one_epoch(model,optimizer,data_loader,device,epoch):
scaler=GradScaler()
model.train()
loss_function=torch.nn.CrossEntropyLoss()
optimizer.zero_grad()
for step,data in enumerate(data_loader):
mels,labels=data
mels=mels.to(device)
labels=labels.to(device)
'''
和正常训练的主要区别
pred=model(mels)
loss=loss_function(pred,label)
train_acc.update((pred,labels))
train_loss.update((pred,labels))
loss.backward()
optimizer.step()
optimizer.zero_grad()
'''
with autocast():
pred=model(mels)
loss=loss_function(pred,label)
train_acc.update((pred,labels))
train_loss.update((pred,labels))
if not torch.isfinite(loss):#最好加上这一句,因为半精度计算的时候,容易出现梯度消失的问题。
print('WARNING: non-finite loss, ending training ', loss)
sys.exit(1)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
return loss,acc