How does the vmap() function work?

According to the JAX documentation on vmap, jax.vmap(fun, in_axes=0, out_axes=0) returns a function that maps fun over the input axes specified by in_axes and stacks the results along the output axis specified by out_axes. The concept is simple, but it took me a while to understand what happens when in_axes or out_axes is not left at its default.
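Here is a minimal sketch of what the defaults mean. It assumes a toy per-example function f, which is hypothetical and not from the post. With in_axes=0 and out_axes=0, vmap should behave like looping over the leading axis and stacking the results along a new leading axis.

```python
import jax
import jax.numpy as jnp

def f(v):
    # Hypothetical per-example function: sum of squares of one vector.
    return jnp.sum(v ** 2)

x = jnp.arange(6.0).reshape(3, 2)  # 3 examples, each a vector of length 2

# Defaults in_axes=0, out_axes=0: call f on each row of x and stack the results along axis 0.
batched = jax.vmap(f)(x)

# The same computation written as an explicit loop plus a stack.
looped = jnp.stack([f(row) for row in x])

print(batched)  # values 1, 13, 41
```

The loop version is only there for comparison. vmap traces f once and produces a vectorized computation instead of calling it three times.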

To understand what it is really doing when in_axes and out_axes change, I create two simple arrays in matrix form, a and b, filled with distinct numbers, and pass them to a simple function, jnp.add, which does element-wise addition.
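The post's own arrays are not reproduced in this text, so this sketch assumes hypothetical 2x2 inputs. The entries of a are 1 to 4, and the entries of b are multiples of 10. Because of that choice, every sum shows which element of a was paired with which element of b. This is the plain element-wise baseline that the vmap variants can be compared against.

```python
import jax.numpy as jnp

# Hypothetical inputs with distinct entries: in any sum, the tens digit identifies
# the element of b and the ones digit identifies the element of a.
a = jnp.array([[1, 2],
               [3, 4]])
b = jnp.array([[10, 20],
               [30, 40]])

baseline = jnp.add(a, b)  # plain element-wise addition, shape (2, 2)
print(baseline)
```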

  • Let’s see an example with in_axes as default.

in_axes can be written as a tuple with one entry per input, so here it has two entries, for a and b (the default in_axes=0 is shorthand for (0, 0)). in_axes[0] indicates which axis of a to map over, and it can take these values: None (a is not sliced, so the whole of a goes into every call), 0 or -2 (map over the 0th axis, also known as the -2 axis in this case, so the operation is performed row by row), or 1 or -1 (map over the 1st axis, also known as the -1 axis in this case, so the operation is performed column by column). Since a has only two dimensions, there is no other axis available for vmap to map over. in_axes[1] takes the same values as in_axes[0], but refers to the axes of b. Note that at least one entry of in_axes must be an integer rather than None. out_axes simply says how the results are stacked: by row (0) or by column (1).
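Here is a sketch of the default case and a few equivalent spellings. It uses the same kind of hypothetical 2x2 inputs with distinct entries, so you can read off which elements were paired.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])      # hypothetical inputs with distinct entries
b = jnp.array([[10, 20], [30, 40]])

# Default in_axes=0 means (0, 0): row i of a is added to row i of b,
# and the per-row results are stacked as rows (out_axes=0).
out_default = jax.vmap(jnp.add)(a, b)

# For 2-D inputs, -2 names the same axis as 0.
out_neg = jax.vmap(jnp.add, in_axes=(-2, -2))(a, b)

# Map over columns instead (1 and -1 are the same axis here) and stack the results as rows ...
out_cols_as_rows = jax.vmap(jnp.add, in_axes=(1, -1), out_axes=0)(a, b)
# ... or stack them as columns with out_axes=1.
out_cols_as_cols = jax.vmap(jnp.add, in_axes=(1, -1), out_axes=1)(a, b)

print(out_default)       # same values as a + b
print(out_cols_as_rows)  # transpose of a + b
```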

  • Change in_axes to customized integers, what would the function return?

The operation is performed along the specified axes. Overall, if every entry of in_axes is an integer, the output keeps the shape (2,2), just like what you would see with a regular np.add(a,b). Only the pairing of elements changes with the axes chosen.
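Here is a sketch of two all-integer in_axes choices on hypothetical inputs with distinct entries. Both outputs have shape (2,2), but the entries show that different elements are paired up.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])      # hypothetical inputs with distinct entries
b = jnp.array([[10, 20], [30, 40]])

# in_axes=(0, 1): call i adds row i of a to column i of b.
rows_a_cols_b = jax.vmap(jnp.add, in_axes=(0, 1))(a, b)
# in_axes=(1, 0): call j adds column j of a to row j of b.
cols_a_rows_b = jax.vmap(jnp.add, in_axes=(1, 0))(a, b)

print(rows_a_cols_b)  # equals a + b.T
print(cols_a_rows_b)  # equals a.T + b
```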

  • Change one of the integers in in_axes to None (at least one entry must stay an integer). What does the result look like?

The example outputs all have shape (2,2,2) instead of the original shape (2,2). Why is that? In these examples the in_axes entry for a is always None, which means a is not sliced at all. Instead, vmap takes slices of b along the axis specified for b, and each slice is added to every row of the whole a. Given that a has shape (2,2) and b has two rows and two columns, you can think of each operation as producing two copies of a and adding one row or column of b to each copy. Stacking the two results gives the extra axis.
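Here is a sketch of the None case on hypothetical inputs with distinct entries. a goes unsliced into both calls, and jnp.add broadcasts each length-2 slice of b against every row of a.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])      # hypothetical inputs with distinct entries
b = jnp.array([[10, 20], [30, 40]])

# in_axes=(None, 0): a is passed whole to every call; call i receives row i of b,
# which broadcasting adds to each row of a.
by_rows_of_b = jax.vmap(jnp.add, in_axes=(None, 0))(a, b)
# in_axes=(None, 1): same idea, but call j receives column j of b.
by_cols_of_b = jax.vmap(jnp.add, in_axes=(None, 1))(a, b)

print(by_rows_of_b.shape)  # (2, 2, 2): two stacked copies of a, each shifted by one slice of b
```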

  • A more complicated example

The example above is simple: both inputs have the same shape and only two dimensions. What if we have two arrays, each of shape (2,3,5,7,9)? What would we expect when changing in_axes? Note that although a and b are initialized differently, they have the same shape.

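Depending on the user-supplied function, only inputs with matching dimensions will work together. First, the axes being mapped must have the same size. Try vmap(jnp.add, in_axes=(0, 1), out_axes=0)(a,b).shape: it raises an error telling you that axis 0 of a has size 2 while axis 1 of b has size 3, so the computation can't happen. Second, the slices passed to the function must be compatible with it (for jnp.add, they must broadcast against each other). When both conditions hold, the operation takes out whichever axis is specified and, with out_axes=0, stacks the results along the 0th axis.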
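Here is a sketch of both behaviours for the 5-D case. The array contents are hypothetical, since only the shapes matter here. Mismatched mapped-axis sizes make vmap raise a ValueError. When the same axis k of both inputs is mapped, that axis moves to the front of the result.

```python
import jax
import jax.numpy as jnp

shape = (2, 3, 5, 7, 9)
# Two same-shaped inputs built differently (hypothetical contents).
a = jnp.arange(2 * 3 * 5 * 7 * 9).reshape(shape)
b = jnp.ones(shape, dtype=a.dtype)

# Axis 0 of a has size 2 but axis 1 of b has size 3, so vmap refuses to pair them.
mismatch_raised = False
try:
    jax.vmap(jnp.add, in_axes=(0, 1), out_axes=0)(a, b)
except ValueError as err:
    mismatch_raised = True
    print(err)

# Mapping axis k of both inputs: each call sees same-shaped slices, and with
# out_axes=0 the mapped axis becomes the leading axis of the result.
shapes = {k: jax.vmap(jnp.add, in_axes=(k, k), out_axes=0)(a, b).shape for k in range(5)}
for k, s in shapes.items():
    print(k, s)
```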

Let’s set one of the in_axes as None and change the out_axes.

You can see there is an additional axis in the output, just like in the simple example. But the returned shape can be a bit confusing. Is there a pattern? I summarized the pattern in an intuitive way in the table below, though it might also be possible to infer the output shape in a different way. First, write down the output shape for out_axes=0. Then swap output axes one step at a time: the output shape for out_axes=n is the shape for out_axes=n-1 with its (n-1)th and nth entries swapped.
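Here is a sketch of an out_axes sweep for the 5-D case. It assumes in_axes=(None, 0), because the post's exact choice is not shown, and hypothetical contents. With jnp.add and two (2,3,5,7,9) inputs, mapping over axis 0 of b is a combination that broadcasts cleanly. The loop checks that each out_axes=n result is the out_axes=0 result with its leading axis moved to position n. The mapped axis has size 2, the same as a's first axis, so out_axes=0 and out_axes=1 give the same shape even though the elements are arranged differently.

```python
import jax
import jax.numpy as jnp

shape = (2, 3, 5, 7, 9)
a = jnp.arange(2 * 3 * 5 * 7 * 9).reshape(shape)  # hypothetical contents
b = jnp.ones(shape, dtype=a.dtype)

def batched_add(n):
    # in_axes=(None, 0): call i computes a + b[i]; b[i] has shape (3, 5, 7, 9) and
    # broadcasts against a, so each call returns shape (2, 3, 5, 7, 9).
    # The 2 calls (b has size 2 along axis 0) are stacked at position n.
    return jax.vmap(jnp.add, in_axes=(None, 0), out_axes=n)(a, b)

out0 = batched_add(0)
shapes = {}
for n in range(out0.ndim):  # non-negative out_axes values 0 .. 5
    out_n = batched_add(n)
    shapes[n] = out_n.shape
    # out_axes=n is the out_axes=0 result with its leading (mapped) axis moved to n.
    assert jnp.array_equal(out_n, jnp.moveaxis(out0, 0, n))
    print(n, out_n.shape)
```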

Summary of in_axes and out_axes in vmap()

Imagine you have two arrays that the custom function is applied to: A1 has shape (a1, b1, c1, ...) with x dimensions, A2 has shape (a2, b2, c2, ...) with y dimensions, and the output has z dimensions.

  1. The values used for in_axes: in_axes = (x', y'), where x' is an integer in [-x, x) and y' is an integer in [-y, y). x' or y' can also be None, but at least one of them must be an integer.

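  2. The scenario in_axes = (n, None) takes slices out of the first array along its nth axis and applies each slice, which has the remaining shape of the first array, to the whole second array. E.g. with A1.shape = (a1,b1,c1) and A2.shape = (a2,b2,c2), in_axes = (1,None) applies slices of shape (a1,c1) to all rows in A2, while in_axes = (2,None) applies slices of shape (a1,b1) to all rows in A2. For a broadcasting function such as jnp.add, a slice of shape (a1,c1) behaves like (1,a1,c1) against the 3-D A2, so this only works when (a1,c1) is compatible with the last two axes of A2, (b2,c2).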

  3. out_axes is an integer in the range [-z, z), where negative values count from the end of the output shape.

  4. To figure out how the shape changes when you change out_axes: first, write down the output shape for out_axes=0. Then swap output axes one step at a time: the output shape for out_axes=n is the output shape for out_axes=n-1 with its (n-1)th and nth entries swapped. Equivalently, out_axes=n moves the mapped axis to position n of the output.
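The four rules above can be checked with a short sketch that uses jnp.add and hypothetical shapes: A1 = (a1, b1, c1) = (2, 3, 4) and A2 = (a2, b2, c2) = (5, 2, 4). These shapes were picked so that a (2, 4) slice of A1 broadcasts against A2. The helper shape_by_swaps is a hypothetical name for the swap rule in item 4.

```python
import jax
import jax.numpy as jnp

A1 = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)          # hypothetical (a1, b1, c1)
A2 = 100 * jnp.arange(5 * 2 * 4).reshape(5, 2, 4)    # hypothetical (a2, b2, c2)

# Item 2: in_axes=(1, None) gives 3 slices A1[:, j, :] of shape (2, 4); A2 goes whole
# into every call, and broadcasting adds each slice to every A2[i].
out = jax.vmap(jnp.add, in_axes=(1, None))(A1, A2)
manual = jnp.stack([A1[:, j, :] + A2 for j in range(A1.shape[1])])
assert out.shape == (3, 5, 2, 4) and jnp.array_equal(out, manual)

# Item 1: negative entries count from the end (-2 is axis 1 of a 3-D array),
# but an in_axes made only of None is rejected.
assert jnp.array_equal(jax.vmap(jnp.add, in_axes=(-2, None))(A1, A2), out)
try:
    jax.vmap(jnp.add, in_axes=(None, None))(A1, A2)
    all_none_rejected = False
except ValueError:
    all_none_rejected = True

# Item 3: the output has z = 4 dimensions, and out_axes=-1 means the last position.
z = out.ndim
last = jax.vmap(jnp.add, in_axes=(1, None), out_axes=-1)(A1, A2)

# Item 4: start from the out_axes=0 shape and swap entries (k-1, k) for k = 1..n.
def shape_by_swaps(shape_at_0, n):
    s = list(shape_at_0)
    for k in range(1, n + 1):
        s[k - 1], s[k] = s[k], s[k - 1]
    return tuple(s)

for n in range(z):
    direct = jax.vmap(jnp.add, in_axes=(1, None), out_axes=n)(A1, A2).shape
    assert direct == shape_by_swaps(out.shape, n)
    print(n, direct)
```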

I want to thank Arkadij Kummer for giving valuable advice on making this blog post easier to understand :) .
