def calculate_loss_and_accuracy(outputs, labels, device):
"""
计算非pad_id的平均loss和准确率
:param outputs:
:param labels:
:param device:
:return:
"""
logits = outputs[0] # 每个token用来预测下一个token的prediction_score,维度:[batch_size,token_len,voca_size]
# 用前n-1个token,预测出第n个token
# 用第i个token的prediction_score用来预测第i+1个token。