Motivation
While writing my code walkthrough of chatglm-6B, I realized I had never really understood where GELU, such a commonly used activation function, actually comes from. So I went back to the paper and the code, and was surprised to find that the implementation we usually see is only an approximation, with quite a bit of background behind it. Let's walk through it.
Background knowledge
Normal distribution and the error function
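As a quick refresher, the CDF of the standard normal distribution can be written in terms of the error function (these are the standard definitions, restated here so the formulas below are self-contained):

$$
\Phi(x) = \int_{-\infty}^{x} \frac{1}{\sqrt{2\pi}}\, e^{-t^{2}/2}\, dt
= \frac{1}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right),
\qquad
\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^{2}}\, dt
$$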
Bernoulli distribution
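In the GELU paper, the activation is motivated as stochastically multiplying the input x by a Bernoulli mask whose probability depends on x itself; taking the expectation over the mask yields the deterministic activation. A brief sketch of that argument:

$$
m \sim \mathrm{Bernoulli}\big(\Phi(x)\big), \qquad \mathbb{E}[x \cdot m] = x \cdot \Phi(x)
$$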
The GELU formula
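Putting the two pieces together gives the exact GELU, along with the tanh approximation that shows up in the code below (the constant 0.044715 is the one used in the paper and in the BERT/GPT implementations):

$$
\mathrm{GELU}(x) = x\,\Phi(x)
= \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
\approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right)
$$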
Code implementations
The four classes below cover one exact computation and three approximations.
import math

import torch
from torch import Tensor, nn


class GELUActivation(nn.Module):
    """
    Original Implementation of the GELU activation function in Google BERT repo when initially created. For
    information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
    Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        if use_gelu_python:
            self.act = self._gelu_python
        else:
            self.act = nn.functional.gelu  # PyTorch built-in; computes the exact erf form by default

    def _gelu_python(self, input: Tensor) -> Tensor:  # exact form
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)
class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: Tensor) -> Tensor:  # approximation 1: tanh form
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
class QuickGELUActivation(nn.Module):
    """
    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:  # approximation 2: sigmoid form
        return input * torch.sigmoid(1.702 * input)
class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:  # approximation 3: same tanh form with sqrt(2/pi) ≈ 0.7978845608 factored in
        return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
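A minimal sketch of how the four variants compare numerically on the same input (assuming the classes and imports above are in scope); the erf-based form and the two tanh approximations agree closely, while QuickGELU drifts a little more:

if __name__ == "__main__":
    x = torch.linspace(-3.0, 3.0, steps=7)

    acts = {
        "exact (erf)": GELUActivation(use_gelu_python=True),
        "new (tanh)": NewGELUActivation(),
        "quick (sigmoid)": QuickGELUActivation(),
        "fast (tanh)": FastGELUActivation(),
    }

    for name, act in acts.items():
        # Print each variant's output so the approximation error is visible side by side.
        print(f"{name:16s}", act(x))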
Deriving the approximations
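A brief sketch of where the approximations come from, with the same constants as in the code above. The tanh form uses a classical approximation of the normal CDF: a cubic correction term inside tanh makes the curve track Φ very closely. The sigmoid form exploits the fact that the logistic function has roughly the same S-shape as Φ, with 1.702 chosen as a good fit:

$$
\Phi(x) \approx \frac{1}{2}\left(1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right)
\;\Rightarrow\;
\mathrm{GELU}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715\,x^{3}\right)\right)\right)
$$

$$
\Phi(x) \approx \sigma(1.702\,x)
\;\Rightarrow\;
\mathrm{GELU}(x) \approx x\,\sigma(1.702\,x), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}
$$

FastGELU is the same tanh form with the factor \(\sqrt{2/\pi} \approx 0.7978845608\) precomputed and pulled inside the expression.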
Closing remarks
This article traced the origins of GELU and walked through its four implementation variants.