0. 序
本篇分析linux内核中的struct reciprocal_value
和struct reciprocal_value_adv
背后的数学原理,参考论文Division by Invariant Integers using Multiplication。
1. basics
核心想法是这样,在除数是常数的情况下,我们希望用乘法和移位指令来代替除法(通常情况下这是更快的)。这可以发生在编译时(作为一种编译优化),或者发生在运行时(软件实现)。整数除法有不同的情况:有符号还是无符号,向0舍入还是向无穷舍入?原论文讨论了许多的情况。我们只分析最常见的一种(linux也只实现了这一种):无符号除法,并向0舍入。
我们假设计算机使用N位bit来表示一个整数,那么可以把问题数学化为:给定除数
d
,
1
≤
d
≤
2
N
−
1
d, 1 \leq d \leq 2^N - 1
d,1≤d≤2N−1,希望找到合适的正整数
m
,
l
m, l
m,l,使得对
∀
n
,
1
≤
n
≤
2
N
−
1
\forall n, 1\leq n \leq 2^N-1
∀n,1≤n≤2N−1,有下式成立
[
n
d
]
=
[
m
n
2
N
+
l
]
[\frac{n}{d} ]= [\frac{mn}{2^{N +l}}]
[dn]=[2N+lmn]
这样我们便可以把除法转化为一个乘法和移位操作了。不过注意这里的乘法是数学上的乘法,换到计算机里,实际需要做的是一个宽度为2N的乘法(即不能丢掉高位bit)。
直观上来看,为了使上式成立,我们需要尽可能得使
m
∗
d
m*d
m∗d接近
2
N
+
l
2^{N+l}
2N+l,取
n
=
d
n=d
n=d,可知
k
=
m
∗
d
−
2
N
+
l
≥
0
k = m*d - 2^{N+l}\geq 0
k=m∗d−2N+l≥0,设
n
=
q
d
+
r
,
0
≤
r
≤
d
−
1
n = qd + r, 0\leq r \leq d - 1
n=qd+r,0≤r≤d−1。上式成立等价于
k
n
2
N
+
l
d
+
r
d
<
1
\frac{kn}{2^{N+l}d}+\frac{r}{d} < 1
2N+ldkn+dr<1
若取
k
≤
2
l
k \leq 2^l
k≤2l,则有
k
n
2
N
+
l
d
+
r
d
<
2
l
∗
2
N
2
N
+
l
∗
d
+
d
−
1
d
=
1
\frac{kn}{2^{N+l}d}+\frac{r}{d} < \frac{2^l*2^N}{2^{N+l}*d} + \frac{d-1}{d} = 1
2N+ldkn+dr<2N+l∗d2l∗2N+dd−1=1
因此我们得到了原论文中的定理4.2
为了保证这样的m存在,我们只需
2
l
+
1
≥
d
2^l + 1 \geq d
2l+1≥d,因此通常取
l
=
⌈
l
o
g
2
d
⌉
m
=
[
2
N
+
l
d
]
+
1
可得范围
:
2
N
+
1
≤
m
≤
2
N
+
1
−
1
l = \lceil log_2d\rceil \\ m = [\frac{2^{N+l}}{d}]+1 \\ 可得范围 : 2^N+1 \leq m \leq 2^{N+1} - 1
l=⌈log2d⌉m=[d2N+l]+1可得范围:2N+1≤m≤2N+1−1
可以看到,m是没办法用Nbit来表示的。因此定义
m
′
=
m
−
2
N
m' = m - 2^N
m′=m−2N。并定义
H
(
x
y
)
H(xy)
H(xy)表示乘法的高N位bit,
L
(
x
y
)
L(xy)
L(xy)表示乘法的低N位bit,这样有
[
n
q
]
=
[
m
n
2
N
+
l
]
=
[
m
′
n
2
N
+
l
+
n
2
l
]
=
[
H
(
m
′
n
)
+
n
2
l
+
L
(
m
′
n
)
2
N
+
l
]
=
[
H
(
m
′
n
)
+
n
2
l
]
最后一步用到了:
L
(
m
′
n
)
2
N
+
l
<
1
2
l
[\frac{n}{q}] = [\frac{mn}{2^{N +l}}] = [\frac{m'n}{2^{N+l}}+\frac{n}{2^l}] = [\frac{H(m'n)+n}{2^l} +\frac{L(m'n)}{2^{N+l}}] = [\frac{H(m'n)+n}{2^l}] \\ 最后一步用到了:\frac{L(m'n)}{2^{N+l}} < \frac{1}{2^l}
[qn]=[2N+lmn]=[2N+lm′n+2ln]=[2lH(m′n)+n+2N+lL(m′n)]=[2lH(m′n)+n]最后一步用到了:2N+lL(m′n)<2l1
我们可以提前算好
m
′
,
l
m',l
m′,l,因此每次除法的开销变为一次高N位bit的乘法,一次加法,一次移位操作。但这里有一个麻烦的地方是这里的加法可能会溢出。按照论文中记
t
1
=
H
(
m
′
n
)
t_1 = H(m'n)
t1=H(m′n),我们改写结果为
[
H
(
m
′
n
)
+
n
2
l
]
=
[
t
1
+
[
n
−
t
1
2
]
2
l
−
1
]
[\frac{H(m'n)+n}{2^l}] = [\frac{t_1+[\frac{n-t_1}{2}]}{2^{l-1}}]
[2lH(m′n)+n]=[2l−1t1+[2n−t1]]
这样避免了溢出的问题,但开销增加到了两次加减法,两次移位,一次高位乘法。另外还需要额外考虑
l
=
0
l=0
l=0的情况(此时
d
=
1
d=1
d=1),这时候这个变换就不合理了。最终得到如下的linux代码(reciprocal_div.c
)
struct reciprocal_value {
u32 m; // 这就是前面提到的 m'
u8 sh1, sh2; // 当 d != 1时,sh1 = 1, sh2 = l - 1
};
struct reciprocal_value reciprocal_value(u32 d) // 预处理函数
{
struct reciprocal_value R;
u64 m;
int l;
l = fls(d - 1);
m = ((1ULL << 32) * ((1ULL << l) - d));
do_div(m, d);
++m;
R.m = (u32)m;
R.sh1 = min(l, 1);
R.sh2 = max(l - 1, 0);
return R;
}
static inline u32 reciprocal_divide(u32 a, struct reciprocal_value R)
{ // 除法替换为乘法,加减法和移位操作
u32 t = (u32)(((u64)a * R.m) >> 32);
return (t + ((a - t) >> R.sh1)) >> R.sh2;
}
2. improvement
在此基础上,论文中提出了若干优化手段:
- 若选取的 m m m是偶数,那么可以用 m / 2 m/2 m/2替换 m m m,同时 l − 1 l-1 l−1替换 l l l,结果不变,但此时 m m m就可以用Nbit来表示了,这样只需要一次高位乘法,一次移位就可以得到结果。
- 若
d
d
d是偶数,那么对某个正整数
e
e
e,有
[
n
d
]
=
[
[
n
/
2
e
]
d
/
2
e
]
[\frac{n}{d}]=[\frac{[n/2^e]}{d/2^e}]
[dn]=[d/2e[n/2e]],这样每次做除法前先将
n
n
n右移
e
e
e位,这样在basics中证明定理4.2时的条件可以放宽为
1
≤
n
<
2
N
−
e
1 \leq n < 2^{N-e}
1≤n<2N−e,这样便可以把
k
k
k的取值范围放宽为
k
≤
2
l
+
e
k \leq 2^{l+e}
k≤2l+e,
因此m会有更宽的选择范围。因此也更容易将m选到偶数(这样说不太准确,因为此时 l = [ l o g 2 d 2 e ] l = [log_2{\frac{d}{2^e}}] l=[log22ed],因此选择宽度并没有发生变化,但是因为 l l l变小了,因此选择范围确实发生了变化,总的来说增加了选到 m m m是偶数的概率) 。
这对应到linux的代码中为
struct reciprocal_value_adv {
u32 m;
u8 sh, exp;
bool is_wide_m; // 如果通过优化,使得m能够用32bit表示,那么is_wide_m为false,
// 此时上面的u32 m表示原始的m,否则is_wide_m为true,u32 m表示的是m'
};
// 第一次调用时传入prec为32,如果返回的is_wide_m为true,那么caller判断d是否为偶数,如果是
// 则再次调用reciprocal_value_adv(d >> e, 32 - e)
struct reciprocal_value_adv reciprocal_value_adv(u32 d, u8 prec)
{
struct reciprocal_value_adv R;
u32 l, post_shift;
u64 mhigh, mlow;
/* ceil(log2(d)) */
l = fls(d - 1);
/* NOTE: mlow/mhigh could overflow u64 when l == 32. This case needs to
* be handled before calling "reciprocal_value_adv", please see the
* comment at include/linux/reciprocal_div.h.
*/
WARN(l == 32,
"ceil(log2(0x%08x)) == 32, %s doesn't support such divisor",
d, __func__);
post_shift = l;
mlow = 1ULL << (32 + l);
do_div(mlow, d);
mhigh = (1ULL << (32 + l)) + (1ULL << (32 + l - prec));
do_div(mhigh, d); // m的取值范围为(mlow, mhigh]
for (; post_shift > 0; post_shift--) { // 选出范围中包含2的幂次最大的m
u64 lo = mlow >> 1, hi = mhigh >> 1;
if (lo >= hi)
break;
mlow = lo;
mhigh = hi;
}
R.m = (u32)mhigh;
R.sh = post_shift;
R.exp = l;
R.is_wide_m = mhigh > U32_MAX;
return R;
}
额外的一点是,linux没有实现
l
=
32
l=32
l=32的情形(此时
d
>
2
31
d > 2^{31}
d>231),因为这个时候需要128位的除法来计算mlow,mhigh
。这种情况下软件模拟来进行128位除法的开销太大了。
That’s all.