Triton 入门指南 | isin算子性能优化
Triton 面向的是数据块编程,屏蔽了大多数硬件细节,降低了开发门槛。开发人员可以专注于数据块划分和算法设计。通过合理的算法设计,Triton 实现的算子完全有可能在性能上超越 pytorch 中的 cuda 实现。
Triton 面向的是数据块编程,屏蔽了大多数硬件细节,降低了开发门槛。开发人员可以专注于数据块划分和算法设计。通过合理的算法设计,Triton 实现的算子完全有可能在性能上超越 pytorch 中的 cuda 实现。
isin 算子的功能与接口
def isin(in0, in1, *, assume_unique: bool = False, invert: bool = False) -> torch.Tensor
功能:判断 in0 的每个元素是否在 in1 出现过,返回一个和 in0 形状相同的布尔类型 Tensor。
assume_unique:假设输入已经过唯一化。
invert:结果取反,如果 in0 的元素不在 in1 中则对应位输出 True。
ATen 实现的 isin 算法
我们想要用 Triton 实现一个标准算子来替换 pytorch 的 cuda 实现,理所应当先参考一下后者的实现算法。ATen 对 isin 算子提供了两种算法:
算法一(小尺寸):in0(示意图中的绿色方块代表其每个元素)和 in1(示意图中的粉色方块代表其每个元素)直接展平,两两比较是否相等,然后将结果归约为 in0.shape。
算法二(大尺寸):二者分别 unique 后,cat 在一起再排序一次,邻位判断相等的输出 True,注意结果需按 unique_order 进行 gather,以获取 unique 前的 in0 位序。
小尺寸实现
算法一(小尺寸):in0 和 in1 直接展平,对位比较,结果归约为 in0.shape。
我们可以将展平比较(pointwise)与归约(any/all)融合为一个核函数。比较是一个简单的逐点运算;规约可以参考 any 或 all 的算子实现,以此为框架可以写出如下实现:
@triton.jit
def isin_by_comparation_impl(
global_pid,
in0_ravel_ptr: tl.tensor,
in1_ravel_ptr: tl.tensor, # in
out_ptr: tl.tensor, # out
M: int, # num_tasks
N: int, # num_tasks_1
BLOCK_M: tl.constexpr, # tile_size
BLOCK_N: tl.constexpr, # tile_size_1
invert: tl.constexpr,
):
row_off = global_pid * BLOCK_M
rows = row_off + tl.arange(0, BLOCK_M)[:, None]
row_mask = rows < M
out_ptr += rows
in0_ravel_ptr += rows + tl.zeros([BLOCK_N], dtype=tl.int32)
in1_ravel_ptr += tl.zeros([BLOCK_M], dtype=tl.int32)[:, None]
block = tl.full([BLOCK_M, BLOCK_N], value=(1 if invert else 0), dtype=tl.int1)
in0 = tl.load(in0_ravel_ptr, row_mask, other=0)
for col_off in range(0, N, BLOCK_N):
cols = col_off + tl.arange(0, BLOCK_N)[None, :]
col_mask = cols < N
mask = row_mask and col_mask
in1 = tl.load(in1_ravel_ptr + cols, mask, other=0)
block = tl.where(
mask,
tl.where(invert, block and (in0 != in1), block or (in0 == in1)),
invert,
)
out = tl.reduce(block, axis=1, combine_fn=(reduce_all if invert else reduce_any))
tl.store(out_ptr, out[:, None], row_mask)
其中 X[:, None]
表示将一维张量 X 第 0 维保持不变,在第 1 维广播成一个二维张量。行之间采用并行调度;列之间在块内(BLOCK_N)采用并行调度,块间采用串行遍历的算法。每个行块初始化一个宽度为 BLOCK_M 的二维张量 block,然后在循环中沿着列遍历所有列,遍历的每一步更新 block 元素值的计算即为 in0 和 in1 展平比较是否相等的过程,遍历完成后将所有列得到的结果归约(axis=1
)。
注意当 invert=False
时,使用或运算(any)规约;当 invert=True
时,使用与运算(all)规约。同时需要注意执行掩码 mask 由行列掩码共同决定,row_mask 是在行维度上限制不超过最大行数(M),col_mask 是在列维度上限制不超过最大列数(N)。
大尺寸算法改良及实现
算法二(大尺寸):二者分别 unique 后,cat 在一起再排序一次,邻位判断相等的输出 True。unique 的结果是生成唯一元素序列 unique_data,以及原张量 in0 在结果序列 unique_data 中的下标 unique_order。考虑到输入数据经 unique 后可能会变化顺序、减少数量,所以应注意 isin 的结果需按 unique_order 进行 gather,以获取 unique 前的 in0 位序。
很显然,ATen 提供的这个算法由多个耗时很高的算子组成,有很大的改进空间。我们可以改成对 in0 的每个元素,在已排序的 in1 中二分查找是否存在。通过 pytorch 的 unique/sort 算子和 searchsorted 算子,我们可以快速搭出一段代码来验证该改进算法的有效性。
进一步的,标准的 searchsorted 算子返回的是查找结果下标,但这对于 isin 算子是无用的,我们的目标只是查找元素是否在 in1 中存在。于是我们在 Triton 中的实现可以改进如下:
@triton.jit
def isin_by_search_impl(
global_pid,
in0_ravel_ptr: tl.tensor,
in1_sorted_ptr: tl.tensor, # in
out_ptr: tl.tensor, # out
M: int, # num_tasks
N: int, # num_tasks_1
log_n: tl.constexpr,
BLOCK_M: tl.constexpr, # tile_size
invert: tl.constexpr,
):
r = tl.arange(0, BLOCK_M)
i0 = global_pid * BLOCK_M + r
mask = i0 < M
# load in0_ravel
in0_ravel = tl.load(in0_ravel_ptr + i0, mask=mask)
# binary search: lower_bound
out = tl.zeros_like(r).to(tl.int1)
start = tl.zeros_like(r)
end = start + N
while_mask = start < end
for i in range(log_n):
mid = tl.where(while_mask, start + (end - start) // 2, 0)
mid_val = tl.load(in1_sorted_ptr + mid, mask=while_mask)
out = tl.where(while_mask, out or (mid_val == in0_ravel), out) # found
start = tl.where(while_mask and (mid_val < in0_ravel), mid + 1, start)
end = tl.where(while_mask and (mid_val > in0_ravel), mid, end)
while_mask = start < end
# store out
tl.store(out_ptr + i0, not out if invert else out, mask=mask)
其中 for 循环内为二分查找过程。这里查找区间是半开半闭区间 [start, end),当然也可以实现成全闭区间的二分查找,效率差别不大,感兴趣的读者可以自行实现作比较。可以看到,与 CPU 上的经典二分查找实现相比,用 Triton 描述二分查找过程,只需要注意的是对 BLOCK_M 行的数据并发的在目标列上作二分查找,循环结束的条件有所不同,而计算逻辑描述是类似的。
性能测试结果
以下测试在 A100 上进行,输入数据类型以 int32 为例,batch_size = 1024。
测试的输入分成了两种情况,稀疏情形(sparse)输入数据的重复率较低,稠密情形(dense)输入数据的重复率较高。其中前两个点因加速比过高,未作于图上。与 pytorch 的原生实现相比,Triton 实现的 isin 算子性能大幅提升,加速比可以达到 1.7 甚至更高。
综上所述,使用 Triton 开发算子,大部分硬件细节被屏蔽,使得开发人员可以专注于算法设计,开发门槛显著降低。合适的算法设计完全有可能获得比 pytorch 的 cuda 实现更高的性能。
完整代码请见 FlagGems 算子库(已开源于 github:https://github.com/FlagOpen/FlagGems)。
扫码回复“Triton”
加入Triton中文社区交流群
欢迎来到由智源人工智能研究院发起的Triton中文社区,这里是一个汇聚了AI开发者、数据科学家、机器学习爱好者以及业界专家的活力平台。我们致力于成为业内领先的Triton技术交流与应用分享的殿堂,为推动人工智能技术的普及与深化应用贡献力量。
更多推荐
所有评论(0)