Triton 面向的是数据块编程,屏蔽了大多数硬件细节,降低了开发门槛。开发人员可以专注于数据块划分和算法设计。通过合理的算法设计,Triton 实现的算子完全有可能在性能上超越 pytorch 中的 cuda 实现。
isin 算子的功能与接口
def isin(in0, in1, *, assume_unique: bool = False, invert: bool = False) -> torch.Tensor
功能:判断 in0 的每个元素是否在 in1 出现过,返回一个和 in0 形状相同的布尔类型 Tensor。
assume_unique:假设输入已经过唯一化。
invert:结果取反,如果 in0 的元素不在 in1 中则对应位输出 True。
ATen 实现的 isin 算法
我们想要用 Triton 实现一个标准算子来替换 pytorch 的 cuda 实现,理所应当先参考一下后者的实现算法。ATen 对 isin 算子提供了两种算法:
算法一(小尺寸):in0(示意图中的绿色方块代表其每个元素)和 in1(示意图中的粉色方块代表其每个元素)直接展平,两两比较是否相等,然后将结果归约为 in0.shape。
算法二(大尺寸):二者分别 unique 后,cat 在一起再排序一次,邻位判断相等的输出 True,注意结果需按 unique_order 进行 gather,以获取 unique 前的 in0 位序。
小尺寸实现
算法一(小尺寸):in0 和 in1 直接展平,对位比较,结果归约为 in0.shape。
我们可以将展平比较(pointwise)与归约(any/all)融合为一个核函数。比较是一个简单的逐点运算;规约可以参考 any 或 all 的算子实现,以此为框