import torch
from functorch import vmap
def torch_mul(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor): # pair-wise multiple
x = w * x + b
x = w * x + b
return w * x + b
gpu_id=0
num=int(192)
dim=int(256)
device = torch.device(f"cuda:{gpu_id}" if (torch.cuda.is_available() and (gpu_id >= 0)) else "cpu")
xs = torch.rand((num, dim), device=device)
ws = torch.rand((num, dim), device=device)
b = torch.rand(dim, device=device)
#vmap函数的两个参数in_dims和out_dims的含义:
#in_dims和out_dims用于表示输入和输出进行并行运算的维度,所以个数和输入输出的元素个数有关
#因为torch_mul的输入元素为x,w,b三个,所以in_dims为三元组(0,0,0)或者(1,1,0)等,当所有元素全部为0时,可以简写为in_dims=0
#在本算例中,指定in_dims为(1,1,0),说明对输入参数x和w均在dim=1上进行并行运算,所以每次x和w这个二维矩阵都是按列进行输入,每次输入一列,所以对函数torch_mul输入的x和w每次都是一个192元素的向量
#而b因为in_dims为(1,1,0),所以按dim=0进行并行运算,因为b是一个256元素的向量,所以每次输入一个标量,根据“广播(broadcast)”作用,标量会按照x和w的尺度扩充,变为192元素,从而实现相加
#如果in_dims为(1,1,None),则表示b不进行并行,所以每次将b作为一个256元素的整体带入计算,这样就会出现192维度和256维度不匹配的情况
#out_dims=0表示输出结果在dim=0上进行合并,由于单个线程的输出为192元素的向量,总共有256个线程,所以将256个输出在第一维度上合并,形成最终(256,192)的输出结果
vmap_func = vmap(torch_mul, in_dims=(1, 1, 0), out_dims=0)
ys = vmap_func(xs, ws, b)
print(ys.shape)
functorch.vmap()共有三个参数,第一个参数是需要进行并行的函数,直接输入函数名即可。第二个和第三个参数分别为in_dims和out_dims,表示原函数的输入和输出需要进行并行计算的维度。
in_dims和out_dims均为元组(元组中元素个数为1个时,可以直接输入标量)。元组中元素的个数和原函数的输入输出的元素个数有关。因为torch_mul的输入元素为x,w,b三个,所以in_dims为三元组(0,0,0)或者(1,1,0)这样的三元组,当所有元素全部为0时,可以简写为in_dims=0。而torch_mul的输出只有一个值,所以out_dims为标量,表示所有并行线程的输出结果在此维度上实现合并。
在本算例中,指定in_dims为(1,1,0),说明对输入参数x和w均在dim=1上进行并行运算,所以每次x和w这个二维矩阵都是按列进行输入,每次输入一列,所以对函数torch_mul输入的x和w每次都是一个192元素的向量。而b因为in_dims为(1,1,0),即b所对应的维度为dim=0,所以在第0个维度上进行并行运算,因为b是一个256元素的向量,“第0个维度”即为256个元素所在的维度,所以每次输入256个元素中的一个(一个标量)。但是根据torch的张量计算法则,在“广播(broadcast)”作用下,标量会按照x和w的尺度扩充,复制为192个相同的元素,从而实现相加。
如果本例中,in_dims被改为(1,1,None),则表示x和w在dim=1上进行并行计算,但b“不在某一个维度上”参与并行。所以实际并行运算的时候,每个线程都将b作为“具有256元素的一个向量”整体带入计算,这样就会在每个线程计算的过程中,出现192维度和256维度不匹配的情况,从而报错。
out_dims=0表示输出结果在dim=0上进行合并,在本例中,由于单个线程的输出为192元素的向量,总共有256个线程,所以将256个输出在第一维度上合并,形成最终(256,192)的输出结果.