0. 参考
- pytorch官方文档:https://pytorch.org/docs/stable/generated/torch.func.vmap.html#torch-func-vmap
- 关于if语句如何执行:https://github.com/pytorch/functorch/issues/257
1. 问题背景
-
笔者现在需要执行如下的功能:
root_ls = [func(x,b) for x in input]
因此突然想到pytorch或许存在对于自定义的函数的向量化执行的支持 -
一顿搜索发现了
from functorch import vmap
这种好东西,虽然还在开发中,但是很多功能已经够用了
2. 具体例子
- 这里只介绍笔者需要的一个方面,
vmap
的其他支持还请参阅pytorch官方文档 - 自定义函数及其输入:
# 自定义函数
def func_2(t,b):
return torch.where((t>5.),
t*b,
-t)
# 输入
t = torch.tensor([1.,2.,3.,4.,5.,6.,7.,8.])
b = torch.tensor([1.],requires_grad=True)
- 注意1:自定义函数不要出现
if
,用torch.where
替代。至于为什么参阅这个issue,大概的原因是“if
isn’t a differentiability requirement;”,强行使用会报错error of Data-dependent control flow
-
然后对于
b
,我们需要扩张到和t
同样的大小:
b_extend = torch.expand_copy(b,size=t.shape) # 必须把b扩张到和t同一个size否则报错
-
利用
vmap
,它返回一个新的函数func_vec
,具有向量化执行的支持,也可以利用autograd
求导
# Use vmap() to construct a new function.
func_vec = vmap(func_2) # [N, D], [N, D] -> [N]
ans = func_vec(t,b_extend)
ans.sum().backward() # 等价于: ans.backward(torch.ones(b_extend.shape))
b_extend.grad # 可以预见:b的导数是t:在t>5.时导数是t,在t<=5.时导数是0
- 全部代码:
import torch
from functorch import vmap
# if分支isn't a differentiability requirement;
def func(t,b):
tmp = t*b
if tmp > 5: # error: Data-dependent control flow
root = t*b
else:
root = -t
return root
def func_2(t,b):
return torch.where((t>5.),
t*b,
-t)
t = torch.tensor([1.,2.,3.,4.,5.,6.,7.,8.])
b = torch.tensor([1.],requires_grad=True)
b_extend = torch.expand_copy(b,size=t.shape) # 必须把b扩张到和t同一个size否则报错
b_extend.retain_grad()
print(f"shape of t:{t.shape}, shape of b_extend:{b_extend.shape}")
# shape of t:torch.Size([8]), shape of b_extend:torch.Size([8])
# Use vmap() to construct a new function. # [D], [D] -> []
func_vec = vmap(func_2) # [N, D], [N, D] -> [N]
ans = func_vec(t,b_extend)
ans.sum().backward() # 等价于: ans.backward(torch.ones(b_extend.shape))
b_extend.grad # 可以预见:b的导数是t:在t>5.时导数是t,在t<=5.时导数是0
# tensor([0., 0., 0., 0., 0., 6., 7., 8.])
- 问题在于,它真的比
root_ls = [func(x,b) for x in input]
这种快吗?在笔者的设计中确实是使用vmap
更快一些,但是不见得总是好用,只是在pytorch中写大量的for
实在是太愚蠢了QAQ
感谢阅读,欢迎交流