在机器学习中,常碰见的一个函数就是softmax,形式如下
这个周末参加WPC,遇见一道注意力机制的偏机器学习题目。很感兴趣,但最终还是没做出来,后面才知道是考数值溢出的,原题如下
这里的α_ts是h对输入特征x1,x2,...,xn的softmax响应函数,即对每个x给予多少“注意力”。直观理解为——如果h对d维中某个维度更“关注”,而x在这个维度的表现也很强,那么分配到的注意力就更高,这里通过向量的点积表示。
因此输入n*d的特征矩阵x1,x2,...,xn,以及m*d的注意力参数矩阵h1,h2,...,hm,返回在m维注意力空间下的响应c1,c2,⋯,cm
这里的坑在于——exp是很容易爆的东西!
在python2中它的上限在709,对于题目中的100*100,是负荷不来的。
>>> math.exp(1)
2.718281828459045
>>> math.exp(10)
22026.465794806718
>>> math.exp(100)
2.6881171418161356e+43
>>> math.exp(709)
8.218407461554972e+307
>>> math.exp(710)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
OverflowError: math range error
解决方案:用z=x-max(x),代替原始softmax输入
由于softmax函数特效,对exp指数的增减常量系数C并不会改变函数输出,因此将输入取值区间变换到(-∞,0],即可有效防止上溢。又因为exp函数大于0,不会出现分母为0,自然消除了下溢。(参考《深度学习》第四章——数值计算)
最终解题代码如下
import os
import sys
import math
stdin = '''3 3 2
1.0 0.8
2.0 -1.0
-1.0 0.2
1.0 -1.0
0.1 0.9
1.1 -0.2'''
#stdin = '''1 1 2
#100 0
#100 0'''
# read input[n*d] and attribute[m*d]
n,m,d = 0,0,0
for idx,line in enumerate(stdin.split('n')): # sys.stdin
if idx == 0:
n,m,d = [int(x) for x in line.split(' ')]
inp = []
att = []
continue
if idx < n+1:
inp.append([float(x) for x in line.split(' ')])
if idx >= n+1:
att.append([float(x) for x in line.split(' ')])
# calcul alpha[m*n]
alpha = []
dot = lambda x,y: sum([x[i]*y[i] for i in range(len(x))])
for i in range(m):
alpha.append([0]*n)
total = 0
## regular method
#for j in range(n):
# alpha[i][j] = math.exp(dot(att[i],inp[j]))
# total += alpha[i][j]
## better method, z = x-max(x)
hx = [dot(att[i],inp[j]) for j in range(n)]
max_hx = max(hx)
hx = [x-max_hx for x in hx]
for j in range(n):
alpha[i][j] = math.exp(hx[j])
total += alpha[i][j]
## normalize
alpha[i] = [alpha[i][j] / total for j in range(n)]
for i in range(m):
c = [0]*d
for j in range(n):
a = alpha[i][j]
for k in range(d):
c[k] += a*inp[j][k]
print ' '.join(['{:.6f}'.format(x) for x in c])
怒刷两道水题,排到300名附近,哈哈哈廉颇老矣尚能饭否?