When importing BertAdam with the following import, a UserWarning appears:
from pytorch_pretrained_bert import BertAdam
UserWarning: This overload of add_ is deprecated:
add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
add_(Tensor other, *, Number alpha) (Triggered internally at …/torch/csrc/utils/python_arg_parser.cpp:1485.)
next_m.mul_(beta1).add_(1 - beta1, grad)
To fix this, copy the class BertAdam(Optimizer) source code into your own code and rewrite the following lines:
# before the change
# next_m.mul_(beta1).add_(1 - beta1, grad)
# next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
# after the change
next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
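The old and new calls compute the same moving-average update; only the argument order changed, with the scalar multiplier moving into the `alpha`/`value` keyword. A minimal sketch of the underlying arithmetic, using plain Python floats in place of tensors (the concrete values are illustrative):

```python
beta1, beta2 = 0.9, 0.999
grad = 0.5          # illustrative gradient value
next_m, next_v = 0.2, 0.04  # illustrative moment estimates

# next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
# computes: m = m * beta1 + (1 - beta1) * grad
next_m = next_m * beta1 + (1 - beta1) * grad

# next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# computes: v = v * beta2 + (1 - beta2) * grad * grad
next_v = next_v * beta2 + (1 - beta2) * grad * grad

print(next_m, next_v)
```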
At this point a new problem appears: cannot find reference 'required' in torch.optim.optimizer, i.e. the line from torch.optim.optimizer import required fails. The fix is to copy the required-related code from optimizer.py directly into your own code instead of importing it. With this change the code runs without the UserWarning.
class _RequiredParameter:
    """Singleton class representing a required parameter for an Optimizer."""
    def __repr__(self) -> str:
        return "<required parameter>"

required = _RequiredParameter()
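With the sentinel defined locally, it can be used the same way torch's optimizers use it: as the default value of a mandatory argument, checked in __init__. A minimal sketch, where DemoOptimizer and its lr check are illustrative stand-ins for BertAdam's own constructor, not part of the library:

```python
class _RequiredParameter:
    """Singleton class representing a required parameter for an Optimizer."""
    def __repr__(self) -> str:
        return "<required parameter>"

required = _RequiredParameter()

class DemoOptimizer:
    # illustrative: `lr` must be supplied explicitly, mirroring how
    # BertAdam uses `required` as the default for its `lr` argument
    def __init__(self, params, lr=required):
        if lr is required:
            raise ValueError("lr is a required parameter")
        self.lr = lr

opt = DemoOptimizer(params=[], lr=0.01)
print(opt.lr)       # prints 0.01
print(repr(required))  # prints <required parameter>
```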