Triton 面向的是数据块编程,屏蔽了大多数硬件细节,降低了开发门槛。开发人员可以专注于数据块划分和算法设计。通过合理的算法设计,Triton 实现的算子完全有可能在性能上超越 pytorch 中的 cuda 实现。

isin 算子的功能与接口 

def isin(in0, in1, *, assume_unique: bool = False, invert: bool = False) -> torch.Tensor

功能:判断 in0 的每个元素是否在 in1 出现过,返回一个和 in0 形状相同的布尔类型 Tensor。

assume_unique:假设输入已经过唯一化。

invert:结果取反,如果 in0 的元素不在 in1 中则对应位输出 True。

ATen 实现的 isin 算法

我们想要用 Triton 实现一个标准算子来替换 pytorch 的 cuda 实现,理所应当先参考一下后者的实现算法。ATen 对 isin 算子提供了两种算法:

算法一(小尺寸):in0(示意图中的绿色方块代表其每个元素)和 in1(示意图中的粉色方块代表其每个元素)直接展平,两两比较是否相等,然后将结果归约为 in0.shape。

算法二(大尺寸):二者分别 unique 后,cat 在一起再排序一次,邻位判断相等的输出 True,注意结果需按 unique_order 进行 gather,以获取 unique 前的 in0 位序。

小尺寸实现

算法一(小尺寸):in0 和 in1 直接展平,对位比较,结果归约为 in0.shape。

我们可以将展平比较(pointwise)与归约(any/all)融合为一个核函数。比较是一个简单的逐点运算;规约可以参考 any 或 all 的算子实现,以此为框架可以写出如下实现:

@triton.jit
def isin_by_comparation_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_ravel_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    BLOCK_M: tl.constexpr,  # tile_size
    BLOCK_N: tl.constexpr,  # tile_size_1
    invert: tl.constexpr,
):
    row_off = global_pid * BLOCK_M
    rows = row_off + tl.arange(0, BLOCK_M)[:, None]
    row_mask = rows < M
    out_ptr += rows
    in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
    in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]

    block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
    in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
    for col_off in range(0, N, BLOCK_N):
        cols = col_off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask
        in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
        block = tl.where(
            mask,
            tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
            invert,
        )
    out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
    tl.store(out_ptr, out[:, None], row_mask)

其中 X[:, None] 表示将一维张量 X 第 0 维保持不变,在第 1 维广播成一个二维张量。行之间采用并行调度;列之间在块内(BLOCK_N)采用并行调度,块间采用串行遍历的算法。每个行块初始化一个宽度为 BLOCK_M 的二维张量 block,然后在循环中沿着列遍历所有列,遍历的每一步更新 block 元素值的计算即为 in0 和 in1 展平比较是否相等的过程,遍历完成后将所有列得到的结果归约(axis=1)。

注意当 invert=False 时,使用或运算(any)规约;当 invert=True 时,使用与运算(all)规约。同时需要注意执行掩码 mask 由行列掩码共同决定,row_mask 是在行维度上限制不超过最大行数(M),col_mask 是在列维度上限制不超过最大列数(N)。

大尺寸算法改良及实现

算法二(大尺寸):二者分别 unique 后,cat 在一起再排序一次,邻位判断相等的输出 True。unique 的结果是生成唯一元素序列 unique_data,以及原张量 in0 在结果序列 unique_data 中的下标 unique_order。考虑到输入数据经 unique 后可能会变化顺序、减少数量,所以应注意 isin 的结果需按 unique_order 进行 gather,以获取 unique 前的 in0 位序。

很显然,ATen 提供的这个算法由多个耗时很高的算子组成,有很大的改进空间。我们可以改成对 in0 的每个元素,在已排序的 in1 中二分查找是否存在。通过 pytorch 的 unique/sort 算子和 searchsorted 算子,我们可以快速搭出一段代码来验证该改进算法的有效性。

进一步的,标准的 searchsorted 算子返回的是查找结果下标,但这对于 isin 算子是无用的,我们的目标只是查找元素是否在 in1 中存在。于是我们在 Triton 中的实现可以改进如下:

@triton.jit
def isin_by_search_impl(
    global_pid,
    in0_ravel_ptr: tl.tensor,
    in1_sorted_ptr: tl.tensor,  # in
    out_ptr: tl.tensor,  # out
    M: int,  # num_tasks
    N: int,  # num_tasks_1
    log_n: tl.constexpr,
    BLOCK_M: tl.constexpr,  # tile_size
    invert: tl.constexpr,
):
    r = tl.arange(0, BLOCK_M)
    i0 = global_pid * BLOCK_M + r
    mask = i0 < M

    # load in0_ravel
    in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)

    # binary search: lower_bound
    out = tl.zeros_like(r).to(tl.int1)
    start = tl.zeros_like(r)
    end = start + N
    while_mask = start < end
    for i in range(log_n):
        mid = tl.where(while_mask, start + (end - start) // 2, 0)
        mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
        out = tl.where(while_mask, out or (mid_val == in0_ravel), out)  # found
        start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
        end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
        while_mask = start < end

    # store out
    tl.store(out_ptr + i0, not out if invert else out, mask=mask)

其中 for 循环内为二分查找过程。这里查找区间是半开半闭区间 [start, end),当然也可以实现成全闭区间的二分查找,效率差别不大,感兴趣的读者可以自行实现作比较。可以看到,与 CPU 上的经典二分查找实现相比,用 Triton 描述二分查找过程,只需要注意的是对 BLOCK_M 行的数据并发的在目标列上作二分查找,循环结束的条件有所不同,而计算逻辑描述是类似的。

性能测试结果

以下测试在 A100 上进行,输入数据类型以 int32 为例,batch_size = 1024。

测试的输入分成了两种情况,稀疏情形(sparse)输入数据的重复率较低,稠密情形(dense)输入数据的重复率较高。其中前两个点因加速比过高,未作于图上。与 pytorch 的原生实现相比,Triton 实现的 isin 算子性能大幅提升,加速比可以达到 1.7 甚至更高。

综上所述,使用 Triton 开发算子,大部分硬件细节被屏蔽,使得开发人员可以专注于算法设计,开发门槛显著降低。合适的算法设计完全有可能获得比 pytorch 的 cuda 实现更高的性能。

完整代码请见 FlagGems 算子库(已开源于 github:https://github.com/FlagOpen/FlagGems)。

图片

扫码回复“Triton”

加入Triton中文社区交流群

Logo

欢迎来到由智源人工智能研究院发起的Triton中文社区,这里是一个汇聚了AI开发者、数据科学家、机器学习爱好者以及业界专家的活力平台。我们致力于成为业内领先的Triton技术交流与应用分享的殿堂,为推动人工智能技术的普及与深化应用贡献力量。

更多推荐