文章目录
1. 目的
训练 lenet 需要初始化 kernel 的 weight 和 bias,而使用 Xavier Glorot 初始化则需要计算
sqrt
(
6.0
f
a
n
i
n
+
f
a
n
o
u
t
)
\text{sqrt}(\frac{6.0}{fan_{in} + fan_{out}})
sqrt(fanin+fanout6.0)(均匀分布) 或
sqrt
(
2.0
f
a
n
i
n
+
f
a
n
o
u
t
)
\text{sqrt}(\frac{2.0}{fan_{in} + fan_{out}})
sqrt(fanin+fanout2.0)(高斯分布).(参考[1]). 为了完全用 C 语言实现 lenet 的训练, 避免依赖 C 标准库的数学库函数 sqrt()
, 考虑弄清楚 sqrt()
的原理, 手动实现一个"低配版": 精度有轻微误差,实现简单。
实现开根号的方法,粗略说有三种:
- 二分法
- 牛顿法
- 卡马克公式快速法
本文只考虑 n , n ∈ R + \sqrt{n}, n \in \R^+ n,n∈R+.
2 二分法求开根号
2.1 数学原理:单调函数
对于正实数 n ∈ R + n \in \R^+ n∈R+, 它的二次方根为 x = n x=\sqrt{n} x=n, 也就是使得 x 2 = n x^2=n x2=n 成立的数字。考察方程 f ( x ) = x 2 − n = 0 f(x)=x^2-n=0 f(x)=x2−n=0 的解:
- 如果 n > 1 n > 1 n>1, 则 n ∈ ( 1 , n ) \sqrt{n} \in (1, n) n∈(1,n), 是一个单调递增区间, s.t. f ( x ) \text{s.t.} f(x) s.t.f(x)有解
- 如果
0
<
n
<
1
0 < n < 1
0<n<1, 则
n
∈
(
n
,
1
)
\sqrt{n} \in (n, 1)
n∈(n,1), 也是一个单调递增区间,
s.t.
f
(
x
)
\text{s.t.} f(x)
s.t.f(x)有解
单调性使得我们可以用二分法求解 f ( x ) = x 2 − n = 0 f(x)=x^2-n=0 f(x)=x2−n=0, 从而得到答案 n = x \sqrt{n}=x n=x
2.2 代码实现:注意事项
比较相等
C语言使用 IEEE-754 标准来表示浮点数, 表示的数字可能和理论数字有误差, 因此判断浮点数相等时往往做差值的绝对值然后和 eps 比较, 小于eps就认为相等。
迭代求解
二分法是一个迭代求解算法, 可以手动设置迭代次数, 也可以设置比较精度 eps,迭代过程中精度误差小于 eps 就停止。本文的实现选择设置 eps 的方式。
防止溢出
本文给出的实现,是用 double 类型计算的。 计算两个数字中点时,有可能超过 double 类型最大值, 因此用先求差值的一半,再加到左端点的方式来计算中点。
特殊数字处理
n
<
0
n < 0
n<0, 直接返回。
n
=
0
n = 0
n=0 和
n
=
1
n = 1
n=1, 直接返回。
2.3 代码实现: 完整代码
#include <stdio.h>
#include <stdbool.h>
double m_fabs(double n)
{
return n >= 0.0 ? n : -n;
}
double m_sqrt(double n)
{
if (n == 0.0 || n == 1.0)
{
return n;
}
if (n < 0.f)
{
printf("Error: not supported n: %f\n", n);
return -1;
}
double left, right;
if (n > 1.0)
{
left = 1.0;
right = n;
}
else
{
left = n;
right = 1.0;
}
double left_v = left * left - n;
double right_v = right * right - n;
if (left_v * right_v > 0)
{
printf("Error: not exist sqrt for n=%f\n", n);
return -2;
}
const double eps = 1e-5;
while (left <= right)
{
printf("left=%f, right=%f\n", left, right);
double mid = left + (right - left) / 2.0;
double value = mid * mid;
if (value - n > eps)
{
right = mid;
}
else if (value - n < -eps)
{
left = mid;
}
else
{
return mid;
}
}
return 233;
}
int main()
{
double n;
while (true)
{
printf(">>> Please input an double number: ");
scanf("%lf", &n);
double ans = m_sqrt(n);
printf("sqrt(%lf) = %lf\n", n, ans);
}
return 0;
}
2.4 验证结果
base) zz@Legion-R7000P% gcc sqrt.c
(base) zz@Legion-R7000P% ./a.out
>>> Please input an double number: 9.0
left=1.000000, right=9.000000
left=1.000000, right=5.000000
sqrt(9.000000) = 3.000000
>>> Please input an double number: 0.04
left=0.040000, right=1.000000
left=0.040000, right=0.520000
left=0.040000, right=0.280000
left=0.160000, right=0.280000
left=0.160000, right=0.220000
left=0.190000, right=0.220000
left=0.190000, right=0.205000
left=0.197500, right=0.205000
left=0.197500, right=0.201250
left=0.199375, right=0.201250
left=0.199375, right=0.200313
left=0.199844, right=0.200313
left=0.199844, right=0.200078
left=0.199961, right=0.200078
sqrt(0.040000) = 0.200020
>>> Please input an double number: 0.01
left=0.010000, right=1.000000
left=0.010000, right=0.505000
left=0.010000, right=0.257500
left=0.010000, right=0.133750
left=0.071875, right=0.133750
left=0.071875, right=0.102813
left=0.087344, right=0.102813
left=0.095078, right=0.102813
left=0.098945, right=0.102813
left=0.098945, right=0.100879
left=0.099912, right=0.100879
left=0.099912, right=0.100396
left=0.099912, right=0.100154
sqrt(0.010000) = 0.100033
>>> Please input an double number: ^C
3. 牛顿法
3.1 数学原理:迭代求解
给定数字 a a a, 求 a \sqrt{a} a. 等价于求方程 f ( x ) = x 2 − a = 0 f(x)=x^2-a = 0 f(x)=x2−a=0 的解。
这个方程在 x 0 x_0 x0 点处的切线 L ( x 0 ) L(x_0) L(x0)方程为 f ( x ) − f ( x 0 ) = f ′ ( x 0 ) ( x − x 0 ) f(x)-f(x_0)=f'(x_0)(x-x_0) f(x)−f(x0)=f′(x0)(x−x0).
切线与 x x x 轴有交点, 也就是当 f ( x ) = 0 f(x)=0 f(x)=0, f ′ ( x 0 ) ( x − x 0 ) + f ( x 0 ) = 0 f'(x_0)(x-x_0) + f(x_0) = 0 f′(x0)(x−x0)+f(x0)=0
⇒ x − x 0 = − f ( x 0 ) / f ′ ( x 0 ) \Rightarrow x-x_0 = -f(x_0)/f'(x_0) ⇒x−x0=−f(x0)/f′(x0)
$\Rightarrow x = x_0 - f(x_0)/f’(x_0) = x_0 - (x_0^2-n)/2x_0 = (x_0 + a/x_0)/2 $
得到 x \sqrt{x} x 的第一个近似解 x 1 = ( x 0 + a x 0 ) / 2 x_1=(x_0+\frac{a}{x_0})/2 x1=(x0+x0a)/2.
通常
x
1
x_1
x1 的精度不足,也就是
x
1
2
x_1 ^ 2
x12 和
a
a
a 相差比较多,因此还需要继续迭代。迭代到第
n
n
n 次时:
$\Rightarrow x_{n+1} = x_{n} - \frac{f_n}{f’(x_n)} = \frac{1}{2} (x_n + \frac{a}{x_n}) $
只要此时 x n 2 {x_n}^2 xn2 和 a a a 足够接近, 或者迭代次数 n n n 足够大, 都可以停止迭代, 用 x n x_n xn 作为 a \sqrt{a} a.
3.2 代码实现
#include <stdio.h>
#include <stdbool.h>
double m_fabs(double n)
{
return n >= 0.0 ? n : -n;
}
double m_sqrt_newton(double a)
{
// x_{n+1} = \frac{1}{2} (x_n + \frac{a}{x_n})
double x = 1.0; // why?
double eps = 1e-5;
while (m_fabs(x * x - a) > eps)
{
printf("x = %lf\n", x);
x = (x + a / x) / 2.0;
}
return x;
}
int main()
{
double n;
while (true)
{
printf(">>> Please input an double number: ");
scanf("%lf", &n);
//double ans = m_sqrt(n);
//printf("sqrt(%lf) = %lf\n", n, ans);
double ans_newton = m_sqrt_newton(n);
printf("sqrt_newton(%lf) = %lf\n", n, ans_newton);
}
return 0;
}
3.3 结果
zz@Legion-R7000P% gcc sqrt.c
zz@Legion-R7000P% ./a.out
>>> Please input an double number: 9.0
x = 1.000000
x = 5.000000
x = 3.400000
x = 3.023529
x = 3.000092
sqrt_newton(9.000000) = 3.000000
>>> Please input an double number: 0.04
x = 1.000000
x = 0.520000
x = 0.298462
x = 0.216241
x = 0.200610
sqrt_newton(0.040000) = 0.200001
>>> Please input an double number: 0.01
x = 1.000000
x = 0.505000
x = 0.262401
x = 0.150255
x = 0.108404
x = 0.100326
sqrt_newton(0.010000) = 0.100001
>>> Please input an double number: ^C
4. 卡马克快速法
4.1 原理
卡马克在雷神之锤游戏中给出了求平方根倒数的一种非常trick的代码实现。把它再求倒数, 就得到开根号结果。
它其实是一种混合方法: 一部分是牛顿法, 另一部分是对数函数的近似。其中牛顿迭代部分用于提升精度, 对数函数的逼近则和 IEEE-754 浮点数表示法紧密结合。
使用的近似公式是 l o g 2 ( 1 + x ) ≈ x + k log_2(1+x) \approx x + k log2(1+x)≈x+k. 见参考[4].
4.2 代码实现
由于 Carmack 快速求平方根的倒数法, 本身目的就是要尽可能快, 因此使用 float 类型而不是 double 类型。
#include <stdio.h>
double m_sqrt_carmack(double n)
{
int i;
float x2, y;
const float threehalfs = 1.5f;
x2 = n * 0.5f;
y = (float)n;
i = *(int*)&y;
i = 0x5f3759df - (i >> 1);
y = *(float *)&i;
y = y * (threehalfs - (x2 * y * y)); // 1st iteration
y = y * (threehalfs - (x2 * y * y)); // 2nd iteration
return 1.0 / y;
}
int main()
{
double n;
while (true)
{
printf(">>> Please input an double number: ");
scanf("%lf", &n);
double ans_carmack = m_sqrt_carmack(n);
printf("sqrt_carmack(%lf) = %lf\n", n, ans_carmack);
}
return 0;
}
4.3 结果
zz@Legion-R7000P% gcc sqrt.c
zz@Legion-R7000P% ./a.out
>>> Please input an double number: 9.0
sqrt_carmack(9.000000) = 3.000006
>>> Please input an double number: 0.04
sqrt_carmack(0.040000) = 0.200001
>>> Please input an double number: 0.01
sqrt_carmack(0.010000) = 0.100000
>>> Please input an double number: ^C
5. 完整代码
// Author: Zhuo Zhang <imzhuo@foxmail.com>
// Homepage: https://github.com/zchrissirhcz
#include <stdio.h>
#include <stdbool.h>
double m_fabs(double n)
{
return n >= 0.0 ? n : -n;
}
double m_sqrt(double n)
{
if (n == 0.0 || n == 1.0)
{
return n;
}
if (n < 0.f)
{
printf("Error: not supported n: %f\n", n);
return -1;
}
double left, right;
if (n > 1.0)
{
left = 1.0;
right = n;
}
else
{
left = n;
right = 1.0;
}
double left_v = left * left - n;
double right_v = right * right - n;
if (left_v * right_v > 0)
{
printf("Error: not exist sqrt for n=%f\n", n);
return -2;
}
const double eps = 1e-5;
while (left <= right)
{
printf("left=%f, right=%f\n", left, right);
double mid = left + (right - left) / 2.0;
double value = mid * mid;
if (value - n > eps)
{
right = mid;
}
else if (value - n < -eps)
{
left = mid;
}
else
{
return mid;
}
}
return 233;
}
double m_sqrt_newton(double a)
{
// x_{n+1} = \frac{1}{2} (x_n + \frac{a}{x_n})
double x = 1.0; // why?
double eps = 1e-5;
while (m_fabs(x * x - a) > eps)
{
printf("x = %lf\n", x);
x = (x + a / x) / 2.0;
}
return x;
}
double m_sqrt_carmack(double n)
{
int i;
float x2, y;
const float threehalfs = 1.5f;
x2 = n * 0.5f;
y = (float)n;
i = *(int*)&y;
i = 0x5f3759df - (i >> 1);
y = *(float *)&i;
y = y * (threehalfs - (x2 * y * y)); // 1st iteration
y = y * (threehalfs - (x2 * y * y)); // 2nd iteration
return 1.0 / y;
}
int main()
{
double n;
while (true)
{
printf(">>> Please input an double number: ");
scanf("%lf", &n);
double ans = m_sqrt(n);
printf("sqrt(%lf) = %lf\n", n, ans);
double ans_newton = m_sqrt_newton(n);
printf("sqrt_newton(%lf) = %lf\n", n, ans_newton);
double ans_carmack = m_sqrt_carmack(n);
printf("sqrt_carmack(%lf) = %lf\n", n, ans_carmack);
}
return 0;
}