About
Softmax 的定义如下,常用于神经网络的最后一层,将输出转换为取值在 (0,1) 之间、且总和为 1 的概率分布:$\mathrm{softmax}(x_i) = \dfrac{e^{x_i}}{\sum_{j} e^{x_j}}$
![防止上下溢出,可以把xi替换为(xi-x_max)](https://img-blog.csdnimg.cn/direct/c438d6bc724840ea9e430ec50a95478e.png#pic_center)
为防止 $e^{x_i}$ 上溢(以及分母全部下溢为 0),可以把 $x_i$ 替换为 $x_i - x_{\max}$:分子和分母同乘 $e^{-x_{\max}}$,结果不变,而此时所有指数都 $\le 0$,且分母中至少有一项等于 1。
Code Implementation
#pragma once
#include <iostream>
#include <vector>
#include <algorithm>
#include <math.h>
#include <numeric>
#include "common.h"
using namespace std;
#ifndef SOFTMAX_H
#define SOFTMAX_H
void vector_test();
void softmax_test();
void accumulate_test();
/// @brief Numerically stable softmax:
///        y_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)).
///
/// Subtracting the maximum before exponentiating keeps exp() from overflowing
/// for large inputs without changing the result, because the common factor
/// exp(-max(x)) cancels between numerator and denominator.
///
/// @param x Input values. Taken by value, so the copy is reused as the output buffer.
/// @return Probabilities in [0, 1] that sum to 1; an empty vector for empty input.
template <typename T>
vector<T> softmax(vector<T> x)
{
    // max_element() on an empty range returns end(); dereferencing it is UB.
    if (x.empty())
    {
        return x;
    }
    const T max_x = *max_element(x.begin(), x.end());
    T sum = 0;
    // Compute each exponential exactly once and store it in place of the input.
    for (T& xi : x)
    {
        xi = exp(xi - max_x);
        sum += xi;
    }
    for (T& xi : x)
    {
        xi /= sum;
    }
    return x;
}
/// @brief Numerically stable softmax written with std::for_each and lambdas.
///
/// Same result as softmax(): each input is shifted by the maximum before
/// exponentiating, then normalised by the sum of the exponentials.
///
/// @param x Input values.
/// @return Probabilities that sum to 1; an empty vector for empty input.
template <typename T>
vector<T> softmax1(vector<T> x)
{
    vector<T> b;
    // max_element() on an empty range returns end(); dereferencing it is UB.
    if (x.empty())
    {
        return b;
    }
    const T x_max = *max_element(x.begin(), x.end());
    T sum = 0;
    b.reserve(x.size());
    // Evaluate exp() once per element and reuse it for both the output and the sum.
    for_each(x.begin(), x.end(), [&](T xi) {
        const T e = exp(xi - x_max);
        b.push_back(e);
        sum += e;
    });
    // sum is final at this point, so capturing it by value is safe.
    for_each(b.begin(), b.end(), [sum](T& bi) { bi /= sum; });
    return b;
}
/// @brief Numerically stable softmax that uses std::accumulate for the normaliser.
///
/// @param x Input values.
/// @return Probabilities that sum to 1; an empty vector for empty input.
template <typename T>
vector<T> softmax2(vector<T> x)
{
    vector<T> b;
    // max_element() on an empty range returns end(); dereferencing it is UB.
    if (x.empty())
    {
        return b;
    }
    const T x_max = *max_element(x.begin(), x.end());
    b.reserve(x.size());
    for_each(x.begin(), x.end(), [&](T xi) { b.push_back(exp(xi - x_max)); });
    // Sum once, before normalising. Calling accumulate() inside the per-element
    // lambda would make this O(n^2), and a [=] capture would also copy b each time.
    const T sum = accumulate(b.begin(), b.end(), static_cast<T>(0));
    for_each(b.begin(), b.end(), [sum](T& bi) { bi /= sum; });
    return b;
}
/// @brief Numerically stable softmax that also prints its result via print3().
///
/// @param x Input values.
/// @return Probabilities that sum to 1; an empty vector for empty input
///         (in which case print3() still emits its trailing blank line).
template <typename T>
vector<T> softmax3(vector<T> x)
{
    vector<T> y;
    // max_element() on an empty range returns end(); dereferencing it is UB.
    if (!x.empty())
    {
        const T x_max = *max_element(x.begin(), x.end());
        y.reserve(x.size());
        for_each(x.begin(), x.end(), [&](T xi) { y.push_back(exp(xi - x_max)); });
        // Sum once, before normalising. Calling accumulate() inside the loop is O(n^2).
        const T sum = accumulate(y.begin(), y.end(), static_cast<T>(0));
        for_each(y.begin(), y.end(), [sum](T& yi) { yi /= sum; });
    }
    print3(y);
    return y;
}
#endif
#include <iostream>
#include <vector>
using namespace std;
#ifndef COMMON_H
#define COMMON_H
/// @brief Writes each element of y to std::cout, one per line.
///
/// Uses a range-for, so this header does not depend on <algorithm> (which it
/// does not include), and '\n' instead of std::endl to avoid a flush per line.
///
/// @param y Values to print; T must support operator<< on std::ostream.
template <typename T>
void print(const vector<T>& y)
{
    for (const auto& v : y)
    {
        cout << v << '\n';
    }
}
/// @brief Writes each element of x to std::cout, one per line, followed by a
///        blank line.
///
/// The vector and its elements are taken by const reference to avoid copies,
/// and '\n' is used instead of std::endl to avoid a flush per line.
///
/// @param x Values to print; T must support operator<< on std::ostream.
template <typename T>
void print3(const vector<T>& x)
{
    for (const auto& xi : x)
    {
        cout << xi << '\n';
    }
    cout << '\n';
}
#endif
#include <iostream>
#include <vector>
#include <algorithm>
#include <math.h>
#include <numeric>
#include "common.h"
#include "softmax.h"
using namespace std;
/// @brief Demonstrates std::vector push_back/pop_back by printing the contents
///        before and after removing the last element.
void vector_test()
{
    cout << "======vector_test begin======" << '\n';
    vector<int> vec;
    vec.push_back(1);
    vec.push_back(3);
    vec.push_back(8);
    vec.push_back(100);
    // Range-for avoids comparing a signed int index with the unsigned size().
    const auto dump = [&vec]() {
        for (const int v : vec)
        {
            cout << v << '\n';
        }
    };
    dump();
    vec.pop_back();
    dump();
    cout << "======vector_test end======" << '\n';
}
/// @brief Runs every softmax variant on the same sample input. It prints the
///        results of softmax, softmax1 and softmax2; softmax3 prints its own result.
void softmax_test()
{
    cout << "======softmax_test begin======" << endl;
    const vector<double> input{ 1.9502, -2.125, 2.60156, 2.05078, -1.77539, -4.21875 };

    const vector<double> out_plain = softmax(input);
    const vector<double> out_foreach = softmax1(input);
    const vector<double> out_accumulate = softmax2(input);
    print(out_plain);
    print(out_foreach);
    print(out_accumulate);

    // softmax3 prints its result itself; the returned vector is not needed here.
    softmax3(input);
    cout << "======softmax_test end======" << endl;
}
/// @brief Demonstrates std::accumulate by summing a fixed list of integers
///        and printing the total.
void accumulate_test()
{
    const vector<int> numbers{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 11 };
    cout << "======accumulate_test begin======" << endl;
    const int total = accumulate(numbers.cbegin(), numbers.cend(), 0);
    cout << total << endl
         << "======accumulate_test end======" << endl;
}
#include "softmax.h"
/// Entry point: runs the three demo routines in a fixed order.
int main()
{
    vector_test();     // std::vector basics
    accumulate_test(); // std::accumulate
    softmax_test();    // the softmax variants
    return 0;
}
References
Item | link |
---|---|
template | https://baijiahao.baidu.com/s?id=1758590740974256916&wfr=spider&for=pc |
accumulate | https://blog.csdn.net/VariatioZbw/article/details/125257536 |
double and float | https://baijiahao.baidu.com/s?id=1766744281104093690&wfr=spider&for=pc |
vector | https://blog.csdn.net/Crocodile1006/article/details/131730798 |
for_each | https://blog.csdn.net/weixin_43165135/article/details/125526408 |
vector | https://blog.csdn.net/wkq0825/article/details/82255984 |
lambda | https://www.cnblogs.com/DswCnblog/p/5629165.html |
softmax c++ | https://blog.csdn.net/weixin_44285683/article/details/126499579? |
softmax c++ | https://blog.csdn.net/qq_21008741/article/details/124496100? |
To be improved
Description | Comments |
---|---|
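Non-template function definitions should not be placed in header files (template definitions such as `softmax` are the exception: they must stay in the header so they can be instantiated) | to be discussed |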
“using namespace std” should be avoided, especially at file scope in a header such as softmax.h | https://www.bilibili.com/video/BV17u4y1N7F5/?spm_id_from=333.1007.0.0&vd_source=80aaa765b368d4c6a71a0c164cc979ed |
“cout << endl;” flushes the output buffer on every call, which hurts performance; use "cout << '\n'" instead | https://www.bilibili.com/video/BV17u4y1N7F5/?spm_id_from=333.1007.0.0&vd_source=80aaa765b368d4c6a71a0c164cc979ed |