import numpy as np
def softmax(x, axis=-1):
    """Return the numerically stable softmax of *x* along *axis*.

    Args:
        x: Array-like of numbers with at least one dimension. Lists and
            tuples are accepted and converted with ``np.asarray``.
        axis: Axis along which the values are normalised. Defaults to the
            last axis.

    Returns:
        A NumPy array with the same shape as *x* whose entries along
        *axis* are non-negative and sum to 1. Floating-point inputs keep
        their dtype; boolean and integer inputs are computed in float64.

    Raises:
        ValueError: If *x* has zero length along *axis* (raised by the
            underlying NumPy max reduction).
    """
    x = np.asarray(x)
    # Integer arithmetic would wrap around for unsigned dtypes (and may
    # overflow for narrow signed ones) when the maximum is subtracted, and
    # booleans do not support subtraction at all, so move to float first.
    if x.dtype.kind in "biu":
        x = x.astype(np.float64)
    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps every exponent <= 0, so np.exp cannot overflow.
    shifted = x - x.max(axis=axis, keepdims=True)
    exponentials = np.exp(shifted)
    return exponentials / exponentials.sum(axis=axis, keepdims=True)
if __name__ == "__main__":
    # Demo: apply softmax to a small 2x2 integer matrix and print the result.
    scores = np.array([[5, 5], [4, 7]])
    probabilities = softmax(scores)
    print(probabilities)
    # Disabled follow-up step, kept for reference:
    # probabilities = probabilities.sum(axis=-2)